def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Creates a variable handle that carries shape-inference metadata.

  In graph mode the `var_handle_op` result is returned directly. Otherwise:

  1. The function refuses to wrap a resource that is already initialized.
  2. It builds a throwaway graph-mode twin of the handle op, only to copy its
     handle data onto the eager handle.
  3. It dismantles the throwaway graph's reference cycles.

  Args:
    shape: Shape passed to `var_handle_op`.
    dtype: Dtype passed to `var_handle_op`.
    shared_name: Resource name passed to `var_handle_op`.
    name: Op name passed to `var_handle_op`.
    graph_mode: If true, return the handle without any eager bookkeeping.

  Returns:
    The handle produced by `gen_resource_variable_ops.var_handle_op`.

  Raises:
    ValueError: When not in graph mode and `var_is_initialized_op` reports
      that the resource already exists.
  """
  # pylint: disable=protected-access
  default_container = ops.get_default_graph()._container
  container = "" if default_container is None else default_container
  handle_args = dict(shape=shape, dtype=dtype, shared_name=shared_name,
                     name=name, container=container)
  handle = gen_resource_variable_ops.var_handle_op(**handle_args)
  if graph_mode:
    return handle

  # Two distinct ResourceVariable objects must never alias one runtime
  # resource. Graph mode guarantees this through unique generated names;
  # eager mode has to ask the runtime explicitly.
  if gen_resource_variable_ops.var_is_initialized_op(handle):
    raise ValueError("variable object with name '%s' already created. Use "
                     "get_variable() if reuse is desired." % shared_name)

  with context.graph_mode(), ops.Graph().as_default() as scratch_graph:
    graph_twin = gen_resource_variable_ops.var_handle_op(**handle_args)
    # Shape inference does not run in eager mode. Copy the handle data
    # (shape/dtype of the pointed-to variable) from the graph twin so that an
    # eager function capturing this handle can still use it.
    if ops._USE_C_SHAPES:
      handle._handle_data = get_resource_handle_data(graph_twin)
    else:
      if graph_twin._handle_data is None:
        ops.set_shape_and_handle_data_for_outputs(graph_twin.op)
      handle._handle_data = graph_twin._handle_data

  # In some Python versions the graph's OrderedDict hides a self-referencing
  # root in a name-mangled attribute. Emptying it breaks that loop so the
  # garbage collector need not. Other versions raise TypeError here, which is
  # harmless.
  ordered_dict_root = scratch_graph._functions.__dict__.get(
      "_OrderedDict__root", None)
  if ordered_dict_root:
    try:
      del ordered_dict_root[0][:]
    except TypeError:
      pass
  # Wipe every attribute of the scratch tensor and graph to break the
  # remaining reference cycles we created.
  graph_twin.__dict__ = {}
  scratch_graph.__dict__ = {}
  # pylint: enable=protected-access
  return handle
def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Creates a variable handle that carries shape-inference metadata.

  Args:
    shape: Shape passed to `var_handle_op`.
    dtype: Dtype passed to `var_handle_op`.
    shared_name: Resource name passed to `var_handle_op`.
    name: Op name passed to `var_handle_op`.
    graph_mode: If true, return the handle without any eager bookkeeping.

  Returns:
    The handle produced by `gen_resource_variable_ops.var_handle_op`.

  Raises:
    ValueError: When not in graph mode and `var_is_initialized_op` reports
      that the resource already exists.
  """
  # pylint: disable=protected-access
  default_container = ops.get_default_graph()._container
  container = "" if default_container is None else default_container
  handle_args = dict(shape=shape, dtype=dtype, shared_name=shared_name,
                     name=name, container=container)
  handle = gen_resource_variable_ops.var_handle_op(**handle_args)
  if graph_mode:
    return handle

  # Graph mode avoids aliasing a resource via unique generated names; eager
  # mode must check the runtime explicitly.
  if gen_resource_variable_ops.var_is_initialized_op(handle):
    raise ValueError("variable object with name '%s' already created. Use "
                     "get_variable() if reuse is desired." % shared_name)

  with context.graph_mode(), ops.Graph().as_default():
    graph_twin = gen_resource_variable_ops.var_handle_op(**handle_args)
    # Shape inference never runs eagerly, so borrow the handle data computed
    # for the graph twin. Eager functions capturing this handle need it.
    handle._handle_data = graph_twin._handle_data
  # pylint: enable=protected-access
  return handle
def _eager_safe_variable_handle(shape, dtype, shared_name, name,
                                container=None):
  """Creates a variable handle that carries shape-inference metadata.

  Outside graph mode, a graph-mode twin of the handle op is built in a
  throwaway graph only to copy its handle data onto the eager handle.

  Args:
    shape: Shape passed to `var_handle_op`.
    dtype: Dtype passed to `var_handle_op`.
    shared_name: Resource name passed to `var_handle_op`.
    name: Op name passed to `var_handle_op`.
    container: Container passed to `var_handle_op` (default None).

  Returns:
    The handle produced by `gen_resource_variable_ops.var_handle_op`.
  """
  handle_args = dict(shape=shape, dtype=dtype, shared_name=shared_name,
                     name=name, container=container)
  handle = gen_resource_variable_ops.var_handle_op(**handle_args)
  if not context.in_graph_mode():
    with context.graph_mode(), ops.Graph().as_default():
      graph_twin = gen_resource_variable_ops.var_handle_op(**handle_args)
      # Shape inference never runs eagerly, so copy the shape/dtype handle
      # data from the graph twin for eager functions that capture the handle.
      handle._handle_data = graph_twin._handle_data  # pylint: disable=protected-access
  return handle
def test_shared_variable(self):
  """Handles alias one resource only when their shared_name matches."""

  def _handle_with_value(resource_name, value):
    # Create the handle first, then the constant, then the assignment,
    # mirroring the original op-creation order.
    var = gen_resource_variable_ops.var_handle_op(
        dtype=tf.float32, shape=(1, 2), shared_name=resource_name)
    gen_resource_variable_ops.assign_variable_op(var, tf.constant(value))
    return var

  def _read(var):
    return gen_resource_variable_ops.read_variable_op(var, dtype=tf.float32)

  # Same explicit shared_name: the second assignment lands on the same
  # resource, so both handles read back equal values.
  first = _handle_with_value("variable_1", [[1.0, 2.0]])
  second = _handle_with_value("variable_1", [[2.0, 3.0]])
  self.assertTrue(tensor_equal(_read(first), _read(second)))

  # Names from `context.shared_name()`: the test expects these two handles
  # NOT to alias, so their values differ.
  first = _handle_with_value(context.shared_name(), [[1.0, 2.0]])
  second = _handle_with_value(context.shared_name(), [[2.0, 3.0]])
  self.assertFalse(tensor_equal(_read(first), _read(second)))
def _create_resource(self):
  """Returns a scalar float32 variable handle named via `anonymous_name()`."""
  return gen_resource_variable_ops.var_handle_op(
      name="StateVar",
      container="",
      dtype=dtypes.float32,
      shape=[],
      shared_name=context.anonymous_name())
def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode,
                                container=None):
  """Creates a variable handle that carries shape-inference metadata.

  Args:
    shape: Shape passed to `var_handle_op`.
    dtype: Dtype passed to `var_handle_op`.
    shared_name: Resource name passed to `var_handle_op`.
    name: Op name passed to `var_handle_op`.
    graph_mode: If true, return the handle without any eager bookkeeping.
    container: Container passed to `var_handle_op` (default None).

  Returns:
    The handle produced by `gen_resource_variable_ops.var_handle_op`.
  """
  handle_args = dict(shape=shape, dtype=dtype, shared_name=shared_name,
                     name=name, container=container)
  handle = gen_resource_variable_ops.var_handle_op(**handle_args)
  if not graph_mode:
    with context.graph_mode(), ops.Graph().as_default():
      graph_twin = gen_resource_variable_ops.var_handle_op(**handle_args)
      # Shape inference never runs eagerly, so copy the shape/dtype handle
      # data from the graph twin for eager functions that capture the handle.
      handle._handle_data = graph_twin._handle_data  # pylint: disable=protected-access
  return handle
def _init_from_args(self,
                    initial_value=None,
                    trainable=True,
                    collections=None,
                    validate_shape=True,
                    caching_device=None,
                    name=None,
                    dtype=None,
                    constraint=None):
  """Creates a variable.

  Args:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. The initial value must have
      a shape specified unless `validate_shape` is set to False. Can also be a
      callable with no argument that returns the initial value when called.
      (Note that initializer functions from init_ops.py must first be bound
       to a shape before being used here.)
    trainable: If `True`, the default, also adds the variable to the graph
      collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
      the default list of variables to use by the `Optimizer` classes.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    validate_shape: Ignored. Provided for compatibility with tf.Variable.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading.  Defaults to the Variable's
      device.  If not `None`, caches on another device.  Typical use is to
      cache on the device where the Ops using the Variable reside, to
      deduplicate copying through `Switch` and other conditional statements.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    dtype: If set, initial_value will be converted to the given type.
      If None, either the datatype will be kept (if initial_value is
      a Tensor) or float32 will be used (if it is a Python object convertible
      to a Tensor). Applied in both graph and eager mode, including when
      `initial_value` is a callable.
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value
      (which must have the same shape). Constraints are not safe to
      use when doing asynchronous distributed training.

  Raises:
    ValueError: If `initial_value` is None, if `collections` is not a list,
      tuple or set, if `constraint` is given but not callable, or if (in
      graph mode) a non-callable initial value was created inside a
      control-flow construct.
  """
  if initial_value is None:
    raise ValueError("initial_value must be specified.")
  init_from_fn = callable(initial_value)

  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  if not isinstance(collections, (list, tuple, set)):
    raise ValueError(
        "collections argument to Variable constructor must be a list, tuple, "
        "or set. Got %s of type %s" % (collections, type(collections)))
  if constraint is not None and not callable(constraint):
    raise ValueError("The `constraint` argument must be a callable.")

  if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
    collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
  self._save_slice_info = None
  in_graph_mode = context.in_graph_mode()
  with ops.control_dependencies(None):
    with ops.name_scope(name, "Variable",
                        [] if init_from_fn else [initial_value]) as name:
      # `name` is a scope name ending in "/"; the resource uses the bare name.
      # pylint: disable=protected-access
      handle_name = ops._name_from_scope_name(name)
      if init_from_fn:
        # Use attr_scope and device(None) to simulate the behavior of
        # colocate_with when the variable we want to colocate with doesn't
        # yet exist.
        if in_graph_mode:
          attr = attr_value_pb2.AttrValue(
              list=attr_value_pb2.AttrValue.ListValue(
                  s=[compat.as_bytes("loc:@%s" % handle_name)]))
          with ops.get_default_graph()._attr_scope({"_class": attr}):
            with ops.name_scope("Initializer"), ops.device(None):
              initial_value = ops.convert_to_tensor(
                  initial_value(), name="initial_value", dtype=dtype)
            self._handle = gen_resource_variable_ops.var_handle_op(
                shape=initial_value.get_shape(),
                dtype=initial_value.dtype.base_dtype,
                shared_name=handle_name,
                name=name)
        else:
          initial_value = initial_value()
          # Convert exactly like the graph path does. Without this, `dtype`
          # was ignored in eager mode, and a callable returning a Python or
          # NumPy value failed on `.get_shape()` below.
          with ops.name_scope("Initializer"):
            initial_value = ops.convert_to_tensor(
                initial_value, name="initial_value", dtype=dtype)
          self._handle = gen_resource_variable_ops.var_handle_op(
              shape=initial_value.get_shape(),
              dtype=initial_value.dtype.base_dtype,
              shared_name=handle_name,
              name=name,
              container="")
      # Or get the initial value from a Tensor or Python object.
      else:
        with ops.name_scope("Initializer"):
          initial_value = ops.convert_to_tensor(
              initial_value, name="initial_value", dtype=dtype)
        # A tensor built inside a loop/conditional cannot serve as a
        # variable initializer; a callable must be used instead.
        if (in_graph_mode and initial_value is not None and
            initial_value.op._get_control_flow_context() is not None):
          raise ValueError(
              "Initializer for variable %s is from inside a control-flow "
              "construct, such as a loop or conditional. When creating a "
              "variable inside a loop or conditional, use a lambda as the "
              "initializer." % name)
        self._handle = gen_resource_variable_ops.var_handle_op(
            shape=initial_value.get_shape(),
            dtype=initial_value.dtype.base_dtype,
            shared_name=handle_name,
            name=name,
            container="")
      # pylint: enable=protected-access

      # The initial value is retained only in graph mode.
      self._initial_value = initial_value if in_graph_mode else None
      self._handle_name = handle_name + ":0"
      self._dtype = initial_value.dtype.base_dtype
      self._constraint = constraint

      if in_graph_mode:
        with ops.name_scope("IsInitialized"):
          self._is_initialized_op = (
              gen_resource_variable_ops.var_is_initialized_op(self._handle))
        if initial_value is not None:
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            self._initializer_op = (
                gen_resource_variable_ops.assign_variable_op(
                    self._handle,
                    self._build_initializer_expr(initial_value),
                    name=n))
        with ops.name_scope("Read"), ops.colocate_with(self._handle):
          # Manually assign reads to the handle's device to avoid log
          # messages.
          with ops.device(self._handle.device):
            value = gen_resource_variable_ops.read_variable_op(
                self._handle, dtype=self._dtype)
          self._graph_element = value
          if caching_device is not None:
            # Variables may be created in a tf.device() or ops.colocate_with()
            # context. At the same time, users would expect caching device to
            # be independent of this context, and/or would not expect the
            # current device context to be merged with the caching device
            # spec. Therefore we reset the colocation stack before creating
            # the cached value. Note that resetting the colocation stack will
            # also reset the device stack.
            with ops.colocate_with(None, ignore_existing=True):
              with ops.device(caching_device):
                self._cached_value = array_ops.identity(value)
          else:
            self._cached_value = None
      else:
        # Eager mode: assign immediately; there are no graph ops to record.
        gen_resource_variable_ops.assign_variable_op(
            self._handle, initial_value)
        self._is_initialized_op = None
        self._initializer_op = None
        self._graph_element = None
        if caching_device:
          with ops.device(caching_device):
            self._cached_value = gen_resource_variable_ops.read_variable_op(
                self._handle, dtype=self._dtype)
        else:
          self._cached_value = None
      ops.add_to_collections(collections, self)
def __init__(self,
             initial_value=None,
             name=None,
             trainable=True,
             collections=None,
             dtype=None,
             shape=None):
  """Creates a variable.

  Args:
    initial_value: A `Tensor` or Python object convertible to a `Tensor`
      representing the initial value of this variable.
    name: The name of this variable. Automatically uniquified.
    trainable: Whether the global read of this variable will be used for
      training.
    collections: Additional collections to which the `read` operation for
      this variable is to be added. Defaults to [].
    dtype: The type of this variable. Can be omitted if it can be deduced
      from the initial_value. If different from the type of the initial
      value it will be cast to this type.
    shape: The shape of this variable. Only specify if there is no initial
      value but shape inference is desired.

  Raises:
    ValueError: If both `initial_value` and `dtype` are None.
  """
  if initial_value is not None:
    initial_value = ops.convert_to_tensor(initial_value)
  if dtype is None:
    # Explicit check instead of `assert`. Assertions are stripped under
    # `python -O`, which would turn this into an opaque AttributeError on
    # `None.dtype` on the next line.
    if initial_value is None:
      raise ValueError("Trying to create a resource variable "
                       "with no dtype or initial value. At"
                       " least one of these must be set.")
    dtype = initial_value.dtype
  elif initial_value is not None:
    initial_value = math_ops.cast(initial_value, dtype)
  if shape is None:
    if initial_value is not None:
      shape = initial_value.get_shape().as_proto()
    else:
      shape = tensor_shape.unknown_shape()
  else:
    shape = tensor_shape.as_shape(shape)

  self._dtype = dtype
  with ops.name_scope(name, "Variable", [initial_value]) as name:
    self._handle = gen_resource_variable_ops.var_handle_op(
        shared_name=name, name=name, dtype=dtype, shape=shape)

    with ops.name_scope("IsInitialized"):
      self._is_initialized_op = (
          gen_resource_variable_ops.var_is_initialized_op(self._handle))
    if initial_value is not None:
      with ops.name_scope("Create"):
        self._initialize_op = gen_resource_variable_ops.create_variable_op(
            self._handle, initial_value)
      # Only registered when a create op exists; `_initialize_op` is not set
      # otherwise.
      resources.register_resource(self._handle,
                                  self._initialize_op,
                                  self._is_initialized_op)

    with ops.name_scope("Read"):
      self._value = gen_resource_variable_ops.read_variable_op(
          self._handle, dtype=self._dtype)
    _register_variable_read(
        self._value, trainable=trainable, collections=collections)
def _init_from_args(self,
                    initial_value=None,
                    trainable=True,
                    collections=None,
                    validate_shape=True,
                    caching_device=None,
                    name=None,
                    dtype=None):
  """Creates a variable.

  Args:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. The initial value must have
      a shape specified unless `validate_shape` is set to False. Can also be a
      callable with no argument that returns the initial value when called.
      (Note that initializer functions from init_ops.py must first be bound
       to a shape before being used here.)
    trainable: If `True`, the default, also adds the variable to the graph
      collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
      the default list of variables to use by the `Optimizer` classes.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    validate_shape: Ignored. Provided for compatibility with tf.Variable.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading.  Defaults to the Variable's
      device.  If not `None`, caches on another device.  Typical use is to
      cache on the device where the Ops using the Variable reside, to
      deduplicate copying through `Switch` and other conditional statements.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    dtype: If set, initial_value will be converted to the given type.
      If None, either the datatype will be kept (if initial_value is
      a Tensor) or float32 will be used (if it is a Python object convertible
      to a Tensor).

  Raises:
    ValueError: If `initial_value` is None or `collections` is not a list,
      tuple or set.
  """
  if initial_value is None:
    raise ValueError("initial_value must be specified.")
  init_from_fn = callable(initial_value)

  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  if not isinstance(collections, (list, tuple, set)):
    raise ValueError(
        "collections argument to Variable constructor must be a list, tuple, "
        "or set. Got %s of type %s" % (collections, type(collections)))
  if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
    collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
  self._save_slice_info = None
  with ops.control_dependencies(None):
    with ops.name_scope(name, "Variable",
                        [] if init_from_fn else [initial_value]) as name:
      # `name` is a scope name with a trailing "/". Use the bare name as the
      # resource's shared_name (previously the raw scope name was used).
      # pylint: disable=protected-access
      true_name = ops._name_from_scope_name(name)
      if init_from_fn:
        # Use attr_scope and device(None) to simulate the behavior of
        # colocate_with when the variable we want to colocate with doesn't
        # yet exist.
        attr = attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(
                s=[compat.as_bytes("loc:@%s" % true_name)]))
        with ops.get_default_graph()._attr_scope({"_class": attr}):
          with ops.name_scope("Initializer"), ops.device(None):
            self._initial_value = ops.convert_to_tensor(
                initial_value(), name="initial_value", dtype=dtype)
          self._handle = gen_resource_variable_ops.var_handle_op(
              shape=self._initial_value.get_shape(),
              dtype=self._initial_value.dtype.base_dtype,
              shared_name=true_name,
              name=name)
      # Or get the initial value from a Tensor or Python object.
      else:
        self._initial_value = ops.convert_to_tensor(
            initial_value, name="initial_value", dtype=dtype)
        self._handle = gen_resource_variable_ops.var_handle_op(
            shape=self._initial_value.get_shape(),
            dtype=self._initial_value.dtype.base_dtype,
            shared_name=true_name,
            name=name)
      # pylint: enable=protected-access
      self._dtype = self._initial_value.dtype.base_dtype

      with ops.name_scope("IsInitialized"):
        self._is_initialized_op = (
            gen_resource_variable_ops.var_is_initialized_op(self._handle))
      if initial_value is not None:
        with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
          self._initialize_op = gen_resource_variable_ops.assign_variable_op(
              self._handle, self._initial_value, name=n)
      with ops.name_scope("Read"), ops.colocate_with(self._handle):
        # Manually assign reads to the handle's device to avoid log messages.
        with ops.device(self._handle.device):
          value = gen_resource_variable_ops.read_variable_op(
              self._handle, dtype=self._dtype)
        self._graph_element = value
        if caching_device is not None:
          # Reset the colocation (and therefore device) stack so the cached
          # copy is placed on `caching_device` alone. It must not be merged
          # with the surrounding tf.device()/colocate_with() context or with
          # the handle colocation above.
          with ops.colocate_with(None, ignore_existing=True):
            with ops.device(caching_device):
              self._cached_value = array_ops.identity(value)
        else:
          self._cached_value = None
      ops.add_to_collections(collections, self)
def _init_from_args(self,
                    initial_value=None,
                    trainable=True,
                    collections=None,
                    validate_shape=True,
                    caching_device=None,
                    name=None,
                    dtype=None):
  """Creates a variable.

  Builds, inside the variable's name scope, the handle op, an IsInitialized
  op, an Assign op for the initial value, and a Read op (plus an optional
  cached identity on `caching_device`). The variable is then added to
  `collections`.

  Args:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. The initial value must have
      a shape specified unless `validate_shape` is set to False. Can also be a
      callable with no argument that returns the initial value when called.
      (Note that initializer functions from init_ops.py must first be bound
       to a shape before being used here.)
    trainable: If `True`, the default, also adds the variable to the graph
      collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
      the default list of variables to use by the `Optimizer` classes.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    validate_shape: Ignored. Provided for compatibility with tf.Variable.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading.  Defaults to the Variable's
      device.  If not `None`, caches on another device.  Typical use is to
      cache on the device where the Ops using the Variable reside, to
      deduplicate copying through `Switch` and other conditional statements.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    dtype: If set, initial_value will be converted to the given type.
      If None, either the datatype will be kept (if initial_value is
      a Tensor) or float32 will be used (if it is a Python object convertible
      to a Tensor).

  Raises:
    ValueError: If `initial_value` is None, or if `collections` is not a
      list, tuple or set.
  """
  if initial_value is None:
    raise ValueError("initial_value must be specified.")
  init_from_fn = callable(initial_value)

  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  if not isinstance(collections, (list, tuple, set)):
    raise ValueError(
        "collections argument to Variable constructor must be a list, tuple, "
        "or set. Got %s of type %s" % (collections, type(collections)))
  if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
    collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
  self._save_slice_info = None
  # Detach variable creation from any enclosing control dependencies.
  with ops.control_dependencies(None):
    with ops.name_scope(name, "Variable", [] if init_from_fn else
                        [initial_value]) as name:
      # `name` is the scope name (with trailing "/"); `true_name` strips it
      # and is used as the resource's shared_name.
      # pylint: disable=protected-access
      true_name = ops._name_from_scope_name(name)
      if init_from_fn:
        # Use attr_scope and device(None) to simulate the behavior of
        # colocate_with when the variable we want to colocate with doesn't
        # yet exist.
        attr = attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(
                s=[compat.as_bytes("loc:@%s" % true_name)]))
        with ops.get_default_graph()._attr_scope({"_class": attr}):
          with ops.name_scope("Initializer"), ops.device(None):
            self._initial_value = ops.convert_to_tensor(
                initial_value(), name="initial_value", dtype=dtype)
          self._handle = gen_resource_variable_ops.var_handle_op(
              shape=self._initial_value.get_shape(),
              dtype=self._initial_value.dtype.base_dtype,
              shared_name=true_name, name=name)
        # pylint: enable=protected-access

      # Or get the initial value from a Tensor or Python object.
      else:
        self._initial_value = ops.convert_to_tensor(
            initial_value, name="initial_value", dtype=dtype)
        self._handle = gen_resource_variable_ops.var_handle_op(
            shape=self._initial_value.get_shape(),
            dtype=self._initial_value.dtype.base_dtype,
            shared_name=true_name, name=name)
      self._dtype = self._initial_value.dtype.base_dtype

      with ops.name_scope("IsInitialized"):
        self._is_initialized_op = (
            gen_resource_variable_ops.var_is_initialized_op(self._handle))
      # Always true here (None was rejected above); kept as written.
      if initial_value is not None:
        with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
          self._initialize_op = gen_resource_variable_ops.assign_variable_op(
              self._handle, self._initial_value, name=n)
      with ops.name_scope("Read"), ops.colocate_with(self._handle):
        # Manually assign reads to the handle's device to avoid log messages.
        with ops.device(self._handle.device):
          value = gen_resource_variable_ops.read_variable_op(
              self._handle, dtype=self._dtype)
        self._graph_element = value
        if caching_device is not None:
          # Variables may be created in a tf.device() or ops.colocate_with()
          # context. At the same time, users would expect caching device to be
          # independent of this context, and/or would not expect the current
          # device context to be merged with the caching device spec.
          # Therefore we reset the colocation stack before creating the cached
          # value. Note that resetting the colocation stack will also reset
          # the device stack.
          with ops.colocate_with(None, ignore_existing=True):
            with ops.device(caching_device):
              self._cached_value = array_ops.identity(value)
        else:
          self._cached_value = None
      ops.add_to_collections(collections, self)
def create_op(self, *args, **kwargs):
  """Creates an `Operation`.

  For operations of the following form

    orig_value = op(*args, **kwargs)

  this function constructs the following subgraph :

    v = Variable()
    if v is not initialized:
      orig_value = op(*args, **kwargs)
      v.assign(orig_value) # Initializes v
      return orig_value
    else:
      return v

  The above transformation is not performed and the original op is returned
  as is if any of the following is true:
  * `_return_as_is` flag is set to true.
  * op_type is listed in _PASS_THROUGH_OPS
  * op has no outputs.
  * One of the op's return value has a ref type.

  Args:
    *args: Arguments for create_op(). When not given as keywords, `op_type`
      is read from args[0] and the output dtypes from args[2].
    **kwargs: Keyword arguments for create_op(). Refer to
      tensorflow.python.framework.ops.Graph.create_op() for the mandatory
      and optional arguments.

  Returns:
    On the pass-through paths, the created op wrapped by `self._wrap`.
    Otherwise a `MultiOutputOperation` over the per-output merged values.

  Raises:
    UnimplementedError: if output type is a reference and the op's type
      is not one of the supported types in `_REF_OPS_WHITELIST`.
  """
  op_type = kwargs['op_type'] if 'op_type' in kwargs else args[0]
  output_dtypes = kwargs['dtypes'] if 'dtypes' in kwargs else args[2]
  output_dtypes = [dtypes.as_dtype(d) for d in output_dtypes]

  if self._return_as_is or op_type in _PASS_THROUGH_OPS:
    return self._wrap(
        super(ImperativeGraph, self).create_op(*args, **kwargs))

  if not output_dtypes:
    return self._wrap(
        super(ImperativeGraph, self).create_op(*args, **kwargs))

  output_has_ref = any([dtype._is_ref_dtype for dtype in output_dtypes])  # pylint: disable=protected-access

  if output_has_ref:
    if op_type not in _REF_OPS_WHITELIST:
      raise errors.UnimplementedError(
          None, None, op_type + ' op not supported in '
          'imperative graph')

    ret = super(ImperativeGraph, self).create_op(*args, **kwargs)

    # Assign ops created while a variable is being created are recorded as
    # pending initializers.
    if self._in_variable_creation:
      if op_type == 'Assign':
        self.add_pending_init(ret)

    return self._wrap(ret)

  with self.return_as_is():
    # Declares the variables to hold the output values of this op.
    op_output_var = [
        state_ops.variable_op_v2(tensor_shape.TensorShape(None), dtype,
                                 container=self._name)
        for dtype in output_dtypes
    ]
    # Ops to free the resources used by the temporary cache variables.
    # The following two ops are created for each cache variable,
    # having no control dependencies on any other ops :
    # var_handle_op ----> destroy_resource_op
    for dtype, v in zip(output_dtypes, op_output_var):
      with ops.control_dependencies(None):
        self._variable_cleanup_ops += [
            gen_resource_variable_ops.destroy_resource_op(
                gen_resource_variable_ops.var_handle_op(
                    dtype, tensor_shape.TensorShape(None),
                    container=self._name, shared_name=v.op.name),
                ignore_lookup_error=True)
        ]

    # Create the conditional to run the original op only when the variable
    # corresponding to the first output is not initialized.
    inited = state_ops.is_variable_initialized(op_output_var[0])
    # ref_switch returns (output_false, output_true): v_f is live when the
    # cache is NOT initialized, v_t when it is.
    v_f, v_t = control_flow_ops.ref_switch(op_output_var[0], inited)
    # pylint: disable=protected-access
    v_f_op = gen_array_ops._ref_identity(v_f)
    v_t_op = gen_array_ops._ref_identity(v_t)
    # pylint: enable=protected-access

    with ops.control_dependencies([v_f_op.op]):
      # Create the original op
      orig_op = self._wrap(
          super(ImperativeGraph, self).create_op(*args, **kwargs))
    shapes = [val.get_shape() for val in orig_op.outputs]

    # Store each output into its cache variable. Outputs whose shape is fully
    # defined with zero elements are skipped.
    controls = []
    for var, val in zip(op_output_var, orig_op.outputs):
      if (not val.get_shape().is_fully_defined() or
          val.get_shape().num_elements() > 0):
        assign_op = state_ops.assign(var, val, validate_shape=False)
        assign_op.set_shape(val.get_shape())
        controls.append(assign_op)

    values = []
    if len(controls) > 1:
      if control_flow_ops.IsSwitch(orig_op):
        # pylint: disable=protected-access
        controls = gen_control_flow_ops._ref_merge(controls)
        # pylint: enable=protected-access
      else:
        controls = control_flow_ops.tuple(controls)

    # For each output, merge the freshly computed value (after the stores)
    # with the previously stored value; whichever branch is live wins.
    for var, val in zip(op_output_var, orig_op.outputs):
      with ops.control_dependencies(controls):
        with self.colocate_with(v_f_op):
          real_val = array_ops.identity(val)
      with ops.control_dependencies([v_t_op.op]):
        with self.colocate_with(v_t_op):
          stored_val = array_ops.identity(var)
        stored_val.set_shape(val.get_shape())
        real_val, _ = control_flow_ops.merge(
            [real_val, stored_val])
      real_val.op.node_def.attr['_gradient_op_type'].CopyFrom(
          attr_value_pb2.AttrValue(
              s=compat.as_bytes(self._merge_op_type)))
      values.append(real_val)

    for i, _ in enumerate(shapes):
      values[i].set_shape(shapes[i])

    self._outputs_map[orig_op.name] = values

    # Tag the original op with the imperative gradient type only if it has a
    # registered gradient function.
    try:
      self._gradient_function_map[
          orig_op.name] = ops.get_gradient_function(orig_op)
    except (KeyError, LookupError):
      pass
    else:
      orig_op.node_def.attr['_gradient_op_type'].CopyFrom(
          attr_value_pb2.AttrValue(
              s=compat.as_bytes(self._imperative_op_type)))

    return MultiOutputOperation(values)
def create_op(self, *args, **kwargs):
  """Creates an `Operation`.

  For operations of the following form

    orig_value = op(*args, **kwargs)

  this function constructs the following subgraph :

    v = Variable()
    if v is not initialized:
      orig_value = op(*args, **kwargs)
      v.assign(orig_value) # Initializes v
      return orig_value
    else:
      return v

  The above transformation is not performed and the original op is returned
  as is if any of the following is true:
  * `_return_as_is` flag is set to true.
  * op_type is listed in _PASS_THROUGH_OPS
  * op has no outputs.
  * One of the op's return value has a ref type.

  Args:
    *args: Arguments for create_op(). When not given as keywords, `op_type`
      is read from args[0] and the output dtypes from args[2].
    **kwargs: Keyword arguments for create_op(). Refer to
      tensorflow.python.framework.ops.Graph.create_op() for the mandatory
      and optional arguments.

  Returns:
    On the pass-through paths, the created op wrapped by `self._wrap`.
    Otherwise a `MultiOutputOperation` built from the per-output merged
    values and the original op.

  Raises:
    UnimplementedError: if output type is a reference and the op's type
      is not one of the supported types in `_REF_OPS_WHITELIST`.
  """
  op_type = kwargs['op_type'] if 'op_type' in kwargs else args[0]
  output_dtypes = kwargs['dtypes'] if 'dtypes' in kwargs else args[2]
  output_dtypes = [dtypes.as_dtype(d) for d in output_dtypes]

  if self._return_as_is or op_type in _PASS_THROUGH_OPS:
    return self._wrap(super(ImperativeGraph, self).create_op(*args, **kwargs))

  if not output_dtypes:
    return self._wrap(
        super(ImperativeGraph, self).create_op(*args, **kwargs))

  output_has_ref = any([dtype._is_ref_dtype for dtype in output_dtypes])  # pylint: disable=protected-access

  if output_has_ref:
    if op_type not in _REF_OPS_WHITELIST:
      raise errors.UnimplementedError(None, None,
                                      op_type + ' op not supported in '
                                      'imperative graph')

    ret = super(ImperativeGraph, self).create_op(*args, **kwargs)

    # Assign ops created while a variable is being created are recorded as
    # pending initializers.
    if self._in_variable_creation:
      if op_type == 'Assign':
        self.add_pending_init(ret)

    return self._wrap(ret)

  with self.return_as_is():
    # Declares the variables to hold the output values of this op.
    op_output_var = [state_ops.variable_op_v2(
        tensor_shape.TensorShape(None), dtype, container=self._name)
                     for dtype in output_dtypes]
    # Ops to free the resources used by the temporary cache variables.
    # The following two ops are created for each cache variable,
    # having no control dependencies on any other ops :
    # var_handle_op ----> destroy_resource_op
    for dtype, v in zip(output_dtypes, op_output_var):
      with ops.control_dependencies(None):
        self._variable_cleanup_ops += [
            gen_resource_variable_ops.destroy_resource_op(
                gen_resource_variable_ops.var_handle_op(
                    dtype, tensor_shape.TensorShape(None),
                    container=self._name, shared_name=v.op.name),
                ignore_lookup_error=True)]

    # Create the conditional to run the original op only when the variable
    # corresponding to the first output is not initialized.
    inited = state_ops.is_variable_initialized(op_output_var[0])
    # ref_switch returns (output_false, output_true): v_f is live when the
    # cache is NOT initialized, v_t when it is.
    v_f, v_t = control_flow_ops.ref_switch(op_output_var[0], inited)
    # pylint: disable=protected-access
    v_f_op = gen_array_ops._ref_identity(v_f)
    v_t_op = gen_array_ops._ref_identity(v_t)
    # pylint: enable=protected-access

    with ops.control_dependencies([v_f_op.op]):
      # Create the original op
      orig_op = self._wrap(
          super(ImperativeGraph, self).create_op(*args, **kwargs))
    shapes = [val.get_shape() for val in orig_op.outputs]

    # Store each output into its cache variable. Outputs whose shape is fully
    # defined with zero elements are skipped.
    controls = []
    for var, val in zip(op_output_var, orig_op.outputs):
      if (not val.get_shape().is_fully_defined() or
          val.get_shape().num_elements() > 0):
        assign_op = state_ops.assign(var, val, validate_shape=False)
        assign_op.set_shape(val.get_shape())
        controls.append(assign_op)

    values = []
    if len(controls) > 1:
      if control_flow_ops.IsSwitch(orig_op):
        # pylint: disable=protected-access
        controls = gen_control_flow_ops._ref_merge(controls)
        # pylint: enable=protected-access
      else:
        controls = control_flow_ops.tuple(controls)

    # For each output, merge the freshly computed value (after the stores)
    # with the previously stored value; whichever branch is live wins.
    for var, val in zip(op_output_var, orig_op.outputs):
      with ops.control_dependencies(controls):
        with self.colocate_with(v_f_op):
          real_val = array_ops.identity(val)
      with ops.control_dependencies([v_t_op.op]):
        with self.colocate_with(v_t_op):
          stored_val = array_ops.identity(var)
        stored_val.set_shape(val.get_shape())
        real_val, _ = control_flow_ops.merge([real_val, stored_val])
      real_val.op.node_def.attr['_gradient_op_type'].CopyFrom(
          attr_value_pb2.AttrValue(s=compat.as_bytes(self._merge_op_type)))
      values.append(real_val)

    for i, _ in enumerate(shapes):
      values[i].set_shape(shapes[i])

    self._outputs_map[orig_op.name] = values

    # Tag the original op with the imperative gradient type only if it has a
    # registered gradient function.
    try:
      self._gradient_function_map[orig_op.name] = ops.get_gradient_function(
          orig_op)
    except (KeyError, LookupError):
      pass
    else:
      orig_op.node_def.attr['_gradient_op_type'].CopyFrom(
          attr_value_pb2.AttrValue(
              s=compat.as_bytes(self._imperative_op_type)))

    return MultiOutputOperation(values, orig_op)
def __init__(self, initial_value=None, name=None, trainable=True,
             collections=None, dtype=None, shape=None):
  """Creates a variable backed by a resource handle.

  Args:
    initial_value: A `Tensor`, or a Python object convertible to a `Tensor`,
      holding the initial value of this variable. May be omitted only when
      `dtype` is given; in that case no creation (initializer) op is built
      and the resource is not registered via `resources.register_resource`.
    name: The name of this variable. Automatically uniquified.
    trainable: Whether the global read of this variable will be used for
      training (forwarded to `_register_variable_read`).
    collections: Additional collections to which the `read` operation for
      this variable is to be added (forwarded to `_register_variable_read`).
    dtype: The type of this variable. Can be omitted if it can be deduced
      from `initial_value`. If different from the type of the initial value,
      the initial value is cast to this type.
    shape: The shape of this variable. Only specify if there is no initial
      value but shape inference is desired. When omitted, the static shape
      of `initial_value` is used, or an unknown shape if there is no
      initial value.

  Raises:
    ValueError: If neither `initial_value` nor `dtype` is provided.
  """
  if initial_value is not None:
    initial_value = ops.convert_to_tensor(initial_value)
  if dtype is None:
    # Explicit check rather than `assert`: assertions are stripped under
    # `python -O`, which would turn this into an opaque AttributeError on
    # `None.dtype` below.
    if initial_value is None:
      raise ValueError("Trying to create a resource variable "
                       "with no dtype or initial value. At"
                       " least one of these must be set.")
    dtype = initial_value.dtype
  elif initial_value is not None:
    initial_value = math_ops.cast(initial_value, dtype)

  # Resolve the static shape handed to the handle op for shape inference.
  if shape is None:
    if initial_value is not None:
      shape = initial_value.get_shape().as_proto()
    else:
      shape = tensor_shape.unknown_shape()
  else:
    shape = tensor_shape.as_shape(shape)

  self._dtype = dtype
  with ops.name_scope(name, "Variable", [initial_value]) as name:
    # The resolved scope name is used both as the op name and as the
    # resource's shared_name.
    self._handle = gen_resource_variable_ops.var_handle_op(
        shared_name=name, name=name, dtype=dtype, shape=shape)

    with ops.name_scope("IsInitialized"):
      self._is_initialized_op = (
          gen_resource_variable_ops.var_is_initialized_op(self._handle))

    if initial_value is not None:
      with ops.name_scope("Create"):
        self._initialize_op = gen_resource_variable_ops.create_variable_op(
            self._handle, initial_value)
      resources.register_resource(self._handle,
                                  self._initialize_op,
                                  self._is_initialized_op)
    else:
      # No creation op exists without an initial value; keep the attribute
      # defined so readers see None instead of an AttributeError.
      self._initialize_op = None

    with ops.name_scope("Read"):
      self._value = gen_resource_variable_ops.read_variable_op(
          self._handle, dtype=self._dtype)
    _register_variable_read(self._value, trainable=trainable,
                            collections=collections)