def define_local(self, name, value):
  """Binds `value` to the local `name` and returns the value actually stored.

  When `self._wrap_locals_in_vars` is true and `value` is a `tf.Tensor` (or a
  single-element `RetvalBag` whose element is a `tf.Tensor`, which is then
  unwrapped), the tensor is cached in a variable: the stored value is
  `identity(cond(is_variable_initialized(var), var, assign(var, tensor)))`,
  so the first evaluation assigns the variable and later ones read it back.

  Args:
    name: Key under which the value is stored in `self._locals`.
    value: The value to bind.

  Returns:
    The value stored in `self._locals[name]` (the variable-backed tensor when
    wrapping applied, otherwise `value` itself).

  Raises:
    Exception: If `name` is already present in `self._locals`.
  """
  if name in self._locals:
    raise Exception("Local already defined: %s" % name)

  should_wrap_in_var = False
  if self._wrap_locals_in_vars:
    if isinstance(value, tf.Tensor):
      should_wrap_in_var = True

    # HACK(adamb) Unwrapping in here really isn't great, since auto-unwrapping
    # can create unexpected behavior.
    if isinstance(value, RetvalBag) and value.len() == 1:
      if isinstance(value.get(None), tf.Tensor):
        should_wrap_in_var = True
        value = value.get(None)

  if should_wrap_in_var:
    # Keep the pre-wrap tensor under its own name so the assign branch below
    # does not depend on `value`, which is rebound by the same statement.
    initial_tensor = value
    variable = state_ops.variable_op_v2(
        initial_tensor.get_shape(), initial_tensor.dtype.base_dtype)
    # Build the caching ops with any enclosing control dependencies cleared.
    with tf.control_dependencies(None):
      value = tf.identity(
          tf.cond(tf.is_variable_initialized(variable),
                  lambda: variable,
                  lambda: tf.assign(variable, initial_tensor)))

  self._locals[name] = value
  return value
def _init_from_args(self,
                    initial_value=None,
                    trainable=True,
                    collections=None,
                    validate_shape=True,
                    caching_device=None,
                    name=None,
                    dtype=None,
                    expected_shape=None):
  """Creates a new variable from arguments.

  Builds the variable op (`self._variable`), its initializer
  (`self._initializer_op`) and a read snapshot (`self._snapshot`), then adds
  `self` to `collections`.

  Args:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. The initial value must have
      a shape specified unless `validate_shape` is set to False. Can also be a
      callable with no argument that returns the initial value when called.
      (Note that initializer functions from init_ops.py must first be bound
      to a shape before being used here.)
    trainable: If `True`, the default, also adds the variable to the graph
      collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
      the default list of variables to use by the `Optimizer` classes.
    collections: List, tuple, or set of graph collection keys. The new
      variable is added to these collections. Defaults to
      `[GraphKeys.GLOBAL_VARIABLES]`.
    validate_shape: If `False`, allows the variable to be initialized with a
      value of unknown shape. If `True`, the default, the shape of
      `initial_value` must be fully defined.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. If `None`, the read snapshot is
      colocated with the variable op. Typical use is to cache on the device
      where the Ops using the Variable reside, to deduplicate copying through
      `Switch` and other conditional statements.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    dtype: If set, initial_value will be converted to the given type. If
      None, the dtype is whatever `ops.convert_to_tensor` produces for
      `initial_value` (a Tensor keeps its own dtype).
    expected_shape: Deprecated. Ignored.

  Raises:
    ValueError: If the initial value is not specified, if `collections` is
      not a list, tuple, or set, or if the initial value does not have a
      fully defined shape and `validate_shape` is `True`.
  """
  _ = expected_shape
  if initial_value is None:
    raise ValueError("initial_value must be specified.")
  init_from_fn = callable(initial_value)

  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  if not isinstance(collections, (list, tuple, set)):
    raise ValueError(
        "collections argument to Variable constructor must be a list, tuple, "
        "or set. Got %s of type %s" % (collections, type(collections)))
  if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
    # Build a new list instead of mutating the caller's collection.
    collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
  # Create the variable and its initializer with any enclosing control
  # dependencies cleared.
  with ops.control_dependencies(None):
    with ops.name_scope(name, "Variable", [] if init_from_fn else
                        [initial_value]) as name:

      if init_from_fn:
        # Use attr_scope and device(None) to simulate the behavior of
        # colocate_with when the variable we want to colocate with doesn't
        # yet exist.
        true_name = ops._name_from_scope_name(name)
        attr = attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(
                s=[compat.as_bytes("loc:@%s" % true_name)]))
        # pylint: disable=protected-access
        with ops.get_default_graph()._attr_scope({"_class": attr}):
          with ops.name_scope("Initializer"), ops.device(None):
            self._initial_value = ops.convert_to_tensor(
                initial_value(), name="initial_value", dtype=dtype)
            shape = (self._initial_value.get_shape()
                     if validate_shape else tensor_shape.unknown_shape())
          self._variable = state_ops.variable_op_v2(
              shape,
              self._initial_value.dtype.base_dtype,
              name=name)

      # Or get the initial value from a Tensor or Python object.
      else:
        self._initial_value = ops.convert_to_tensor(
            initial_value, name="initial_value", dtype=dtype)
        shape = (self._initial_value.get_shape()
                 if validate_shape else tensor_shape.unknown_shape())
        # In this case, the variable op can't be created until after the
        # initial_value has been converted to a Tensor with a known type.
        self._variable = state_ops.variable_op_v2(
            shape,
            self._initial_value.dtype.base_dtype,
            name=name)

      # With validate_shape, the initial value's static shape must be fully
      # defined. Note this check runs after the variable op was created.
      if validate_shape:
        initial_value_shape = self._initial_value.get_shape()
        if not initial_value_shape.is_fully_defined():
          raise ValueError("initial_value must have a shape specified: %s" %
                           self._initial_value)

      # Assigns initial value.
      self._initializer_op = state_ops.assign(
          self._variable, self._initial_value,
          validate_shape=validate_shape).op

      # TODO(vrv): Change this class to not take caching_device, but
      # to take the op to colocate the snapshot with, so we can use
      # colocation rather than devices.
      if caching_device is not None:
        with ops.device(caching_device):
          self._snapshot = array_ops.identity(self._variable, name="read")
      else:
        with ops.colocate_with(self._variable.op):
          self._snapshot = array_ops.identity(self._variable, name="read")

  ops.add_to_collections(collections, self)
  self._caching_device = caching_device
  self._save_slice_info = None
def create_op(self, *args, **kwargs):
  """Creates an `Operation`.

  For operations of the following form

    orig_value = op(*args, **kwargs)

  this function constructs the following subgraph :

    v = Variable()
    if v is not initialized:
      orig_value = op(*args, **kwargs)
      v.assign(orig_value) # Initializes v
      return orig_value
    else:
      return v

  The above transformation is not performed and the original op is returned
  as is if any of the following is true:
  * `_return_as_is` flag is set to true.
  * op_type is listed in _PASS_THROUGH_OPS
  * op has no outputs.
  * One of the op's return value has a ref type.

  Args:
    *args: Arguments for create_op()
    **kwargs: Keyword arguments for create_op(). Refer to
      tensorflow.python.framework.ops.Graph.create_op() for the mandatory
      and optional arguments.

  Returns:
    An Operation: on the pass-through paths, the parent-created op passed
    through `self._wrap`; otherwise a `MultiOutputOperation` over the merged
    per-output values.

  Raises:
    UnimplementedError: if output type is a reference and the op's type
      is not one of the supported types in `_REF_OPS_WHITELIST`.
  """
  # `op_type` and `dtypes` come from kwargs when passed by keyword, otherwise
  # from positional slots 0 and 2.
  op_type = kwargs['op_type'] if 'op_type' in kwargs else args[0]
  output_dtypes = kwargs['dtypes'] if 'dtypes' in kwargs else args[2]
  output_dtypes = [dtypes.as_dtype(d) for d in output_dtypes]

  if self._return_as_is or op_type in _PASS_THROUGH_OPS:
    return self._wrap(
        super(ImperativeGraph, self).create_op(*args, **kwargs))

  # An op with no outputs has nothing to cache.
  if not output_dtypes:
    return self._wrap(
        super(ImperativeGraph, self).create_op(*args, **kwargs))

  output_has_ref = any([dtype._is_ref_dtype for dtype in output_dtypes])  # pylint: disable=protected-access

  if output_has_ref:
    if op_type not in _REF_OPS_WHITELIST:
      raise errors.UnimplementedError(
          None, None, op_type + ' op not supported in '
          'imperative graph')

    ret = super(ImperativeGraph, self).create_op(*args, **kwargs)

    # Assign ops created while `self._in_variable_creation` is set are handed
    # to `add_pending_init`.
    if self._in_variable_creation:
      if op_type == 'Assign':
        self.add_pending_init(ret)

    return self._wrap(ret)

  # NOTE(review): presumably return_as_is() sets `self._return_as_is` so the
  # ops built below take the pass-through path above instead of being cached
  # themselves; confirm against its definition.
  with self.return_as_is():
    # Declares the variables to hold the output values of this op.
    op_output_var = [
        state_ops.variable_op_v2(tensor_shape.TensorShape(None), dtype,
                                 container=self._name)
        for dtype in output_dtypes
    ]

    # Ops to free the resources used by the temporary cache variables.
    # The following two ops are created for each cache variable,
    # having no control dependencies on any other ops :
    # var_handle_op ----> destroy_resource_op
    for dtype, v in zip(output_dtypes, op_output_var):
      with ops.control_dependencies(None):
        self._variable_cleanup_ops += [
            gen_resource_variable_ops.destroy_resource_op(
                gen_resource_variable_ops.var_handle_op(
                    dtype, tensor_shape.TensorShape(None),
                    container=self._name, shared_name=v.op.name),
                ignore_lookup_error=True)
        ]

    # Create the conditional to run the original op only when the variable
    # corresponding to the first output is not initialized.
    inited = state_ops.is_variable_initialized(op_output_var[0])
    # v_f is forwarded when `inited` is false, v_t when it is true.
    v_f, v_t = control_flow_ops.ref_switch(op_output_var[0], inited)
    # pylint: disable=protected-access
    v_f_op = gen_array_ops._ref_identity(v_f)
    v_t_op = gen_array_ops._ref_identity(v_t)
    # pylint: enable=protected-access

    with ops.control_dependencies([v_f_op.op]):
      # Create the original op
      orig_op = self._wrap(
          super(ImperativeGraph, self).create_op(*args, **kwargs))
    # Static output shapes, re-applied to the merged values at the end.
    shapes = [val.get_shape() for val in orig_op.outputs]

    # Assign each output into its cache variable. An output whose static
    # shape is fully defined with zero elements gets no assign op.
    controls = []
    for var, val in zip(op_output_var, orig_op.outputs):
      if (not val.get_shape().is_fully_defined() or
          val.get_shape().num_elements() > 0):
        assign_op = state_ops.assign(var, val, validate_shape=False)
        assign_op.set_shape(val.get_shape())
        controls.append(assign_op)

    values = []
    # With several assigns, combine them (a ref merge when the original op is
    # a Switch, otherwise control_flow_ops.tuple) before they are used as
    # control dependencies below.
    if len(controls) > 1:
      if control_flow_ops.IsSwitch(orig_op):
        # pylint: disable=protected-access
        controls = gen_control_flow_ops._ref_merge(controls)
        # pylint: enable=protected-access
      else:
        controls = control_flow_ops.tuple(controls)

    # Each result merges the freshly computed output (read after the assigns,
    # colocated with v_f_op) with the cached variable value (read after
    # v_t_op).
    for var, val in zip(op_output_var, orig_op.outputs):
      with ops.control_dependencies(controls):
        with self.colocate_with(v_f_op):
          real_val = array_ops.identity(val)
      with ops.control_dependencies([v_t_op.op]):
        with self.colocate_with(v_t_op):
          stored_val = array_ops.identity(var)
        stored_val.set_shape(val.get_shape())
        real_val, _ = control_flow_ops.merge(
            [real_val, stored_val])
      # NOTE(review): this mutates `node_def` in place; confirm the TF version
      # in use returns the op's live NodeDef rather than a copy.
      real_val.op.node_def.attr['_gradient_op_type'].CopyFrom(
          attr_value_pb2.AttrValue(
              s=compat.as_bytes(self._merge_op_type)))
      values.append(real_val)

    for i, _ in enumerate(shapes):
      values[i].set_shape(shapes[i])
    self._outputs_map[orig_op.name] = values
    # If ops.get_gradient_function raises KeyError/LookupError nothing is
    # recorded and the op is not tagged; otherwise its gradient function is
    # recorded and the op is tagged with `self._imperative_op_type`.
    try:
      self._gradient_function_map[
          orig_op.name] = ops.get_gradient_function(orig_op)
    except (KeyError, LookupError):
      pass
    else:
      orig_op.node_def.attr['_gradient_op_type'].CopyFrom(
          attr_value_pb2.AttrValue(
              s=compat.as_bytes(self._imperative_op_type)))

    return MultiOutputOperation(values)
def create_op(self, *args, **kwargs):
  """Creates an `Operation`.

  For operations of the following form

    orig_value = op(*args, **kwargs)

  this function constructs the following subgraph :

    v = Variable()
    if v is not initialized:
      orig_value = op(*args, **kwargs)
      v.assign(orig_value) # Initializes v
      return orig_value
    else:
      return v

  The above transformation is not performed and the original op is returned
  as is if any of the following is true:
  * `_return_as_is` flag is set to true.
  * op_type is listed in _PASS_THROUGH_OPS
  * op has no outputs.
  * One of the op's return value has a ref type.

  Args:
    *args: Arguments for create_op()
    **kwargs: Keyword arguments for create_op(). Refer to
      tensorflow.python.framework.ops.Graph.create_op() for the mandatory
      and optional arguments.

  Returns:
    An Operation: on the pass-through paths, the parent-created op passed
    through `self._wrap`; otherwise a `MultiOutputOperation` built from the
    merged per-output values and the original op.

  Raises:
    UnimplementedError: if output type is a reference and the op's type
      is not one of the supported types in `_REF_OPS_WHITELIST`.
  """
  # `op_type` and `dtypes` come from kwargs when passed by keyword, otherwise
  # from positional slots 0 and 2.
  op_type = kwargs['op_type'] if 'op_type' in kwargs else args[0]
  output_dtypes = kwargs['dtypes'] if 'dtypes' in kwargs else args[2]
  output_dtypes = [dtypes.as_dtype(d) for d in output_dtypes]

  if self._return_as_is or op_type in _PASS_THROUGH_OPS:
    return self._wrap(super(ImperativeGraph, self).create_op(*args, **kwargs))

  # An op with no outputs has nothing to cache.
  if not output_dtypes:
    return self._wrap(
        super(ImperativeGraph, self).create_op(*args, **kwargs))

  output_has_ref = any([dtype._is_ref_dtype for dtype in output_dtypes])  # pylint: disable=protected-access

  if output_has_ref:
    if op_type not in _REF_OPS_WHITELIST:
      raise errors.UnimplementedError(None, None,
                                      op_type + ' op not supported in '
                                      'imperative graph')

    ret = super(ImperativeGraph, self).create_op(*args, **kwargs)

    # Assign ops created while `self._in_variable_creation` is set are handed
    # to `add_pending_init`.
    if self._in_variable_creation:
      if op_type == 'Assign':
        self.add_pending_init(ret)

    return self._wrap(ret)

  # NOTE(review): presumably return_as_is() sets `self._return_as_is` so the
  # ops built below take the pass-through path above instead of being cached
  # themselves; confirm against its definition.
  with self.return_as_is():
    # Declares the variables to hold the output values of this op.
    op_output_var = [state_ops.variable_op_v2(
        tensor_shape.TensorShape(None), dtype, container=self._name)
                     for dtype in output_dtypes]

    # Ops to free the resources used by the temporary cache variables.
    # The following two ops are created for each cache variable,
    # having no control dependencies on any other ops :
    # var_handle_op ----> destroy_resource_op
    for dtype, v in zip(output_dtypes, op_output_var):
      with ops.control_dependencies(None):
        self._variable_cleanup_ops += [
            gen_resource_variable_ops.destroy_resource_op(
                gen_resource_variable_ops.var_handle_op(
                    dtype, tensor_shape.TensorShape(None),
                    container=self._name, shared_name=v.op.name),
                ignore_lookup_error=True)]

    # Create the conditional to run the original op only when the variable
    # corresponding to the first output is not initialized.
    inited = state_ops.is_variable_initialized(op_output_var[0])
    # v_f is forwarded when `inited` is false, v_t when it is true.
    v_f, v_t = control_flow_ops.ref_switch(op_output_var[0], inited)
    # pylint: disable=protected-access
    v_f_op = gen_array_ops._ref_identity(v_f)
    v_t_op = gen_array_ops._ref_identity(v_t)
    # pylint: enable=protected-access

    with ops.control_dependencies([v_f_op.op]):
      # Create the original op
      orig_op = self._wrap(
          super(ImperativeGraph, self).create_op(*args, **kwargs))
    # Static output shapes, re-applied to the merged values at the end.
    shapes = [val.get_shape() for val in orig_op.outputs]

    # Assign each output into its cache variable. An output whose static
    # shape is fully defined with zero elements gets no assign op.
    controls = []
    for var, val in zip(op_output_var, orig_op.outputs):
      if (not val.get_shape().is_fully_defined() or
          val.get_shape().num_elements() > 0):
        assign_op = state_ops.assign(var, val, validate_shape=False)
        assign_op.set_shape(val.get_shape())
        controls.append(assign_op)

    values = []
    # With several assigns, combine them (a ref merge when the original op is
    # a Switch, otherwise control_flow_ops.tuple) before they are used as
    # control dependencies below.
    if len(controls) > 1:
      if control_flow_ops.IsSwitch(orig_op):
        # pylint: disable=protected-access
        controls = gen_control_flow_ops._ref_merge(controls)
        # pylint: enable=protected-access
      else:
        controls = control_flow_ops.tuple(controls)

    # Each result merges the freshly computed output (read after the assigns,
    # colocated with v_f_op) with the cached variable value (read after
    # v_t_op).
    for var, val in zip(op_output_var, orig_op.outputs):
      with ops.control_dependencies(controls):
        with self.colocate_with(v_f_op):
          real_val = array_ops.identity(val)
      with ops.control_dependencies([v_t_op.op]):
        with self.colocate_with(v_t_op):
          stored_val = array_ops.identity(var)
        stored_val.set_shape(val.get_shape())
        real_val, _ = control_flow_ops.merge([real_val, stored_val])
      # NOTE(review): this mutates `node_def` in place; confirm the TF version
      # in use returns the op's live NodeDef rather than a copy.
      real_val.op.node_def.attr['_gradient_op_type'].CopyFrom(
          attr_value_pb2.AttrValue(s=compat.as_bytes(self._merge_op_type)))
      values.append(real_val)

    for i, _ in enumerate(shapes):
      values[i].set_shape(shapes[i])
    self._outputs_map[orig_op.name] = values
    # If ops.get_gradient_function raises KeyError/LookupError nothing is
    # recorded and the op is not tagged; otherwise its gradient function is
    # recorded and the op is tagged with `self._imperative_op_type`.
    try:
      self._gradient_function_map[orig_op.name] = ops.get_gradient_function(
          orig_op)
    except (KeyError, LookupError):
      pass
    else:
      orig_op.node_def.attr['_gradient_op_type'].CopyFrom(
          attr_value_pb2.AttrValue(
              s=compat.as_bytes(self._imperative_op_type)))

    return MultiOutputOperation(values, orig_op)
def _init_from_args(self,
                    initial_value=None,
                    trainable=True,
                    collections=None,
                    validate_shape=True,
                    caching_device=None,
                    name=None,
                    dtype=None,
                    expected_shape=None,
                    constraint=None):
  """Creates a new variable from arguments.

  Builds the variable op (`self._variable`), its initializer
  (`self._initializer_op`) and a read snapshot (`self._snapshot`), then adds
  `self` to `collections` and stores `constraint`.

  Args:
    initial_value: A `Tensor`, a Python object convertible to a `Tensor`, or
      a callable with no arguments returning one. A
      `checkpointable.CheckpointInitialValue` is unwrapped to its
      `wrapped_value`, and its checkpoint restore uid is recorded.
    trainable: If `True`, also adds the variable to
      `GraphKeys.TRAINABLE_VARIABLES`.
    collections: List, tuple, or set of graph collection keys to add the
      variable to. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    validate_shape: If `True`, the initial value's static shape must be fully
      defined; if `False`, the variable is created with an unknown shape.
    caching_device: Optional device (string or function) on which the read
      snapshot is created. If `None`, the snapshot is colocated with the
      variable op.
    name: Optional name; the name scope defaults to `"Variable"`.
    dtype: If set, the initial value is converted to this dtype by
      `ops.convert_to_tensor`.
    expected_shape: Ignored.
    constraint: Optional callable; stored as `self._constraint`.

  Raises:
    ValueError: If `initial_value` is None, `collections` is not a list,
      tuple, or set, `constraint` is not callable, a non-callable initial
      value was created inside a control-flow construct, or `validate_shape`
      is set and the initial value's shape is not fully defined.
    RuntimeError: If executing eagerly inside `ops.init_scope()`.
  """
  _ = expected_shape
  if initial_value is None:
    raise ValueError("initial_value must be specified.")
  # Note: computed before a CheckpointInitialValue is unwrapped below.
  init_from_fn = callable(initial_value)

  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  if not isinstance(collections, (list, tuple, set)):
    raise ValueError(
        "collections argument to Variable constructor must be a list, tuple, "
        "or set. Got %s of type %s" % (collections, type(collections)))
  if constraint is not None and not callable(constraint):
    raise ValueError("The `constraint` argument must be a callable.")

  # Store the graph key so optimizers know how to only retrieve variables from
  # this graph.
  self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  if isinstance(initial_value, checkpointable.CheckpointInitialValue):
    self._maybe_initialize_checkpointable()
    self._update_uid = initial_value.checkpoint_position.restore_uid
    initial_value = initial_value.wrapped_value

  if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
    # Build a new list instead of mutating the caller's collection.
    collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
  with ops.init_scope():
    # Ensure that we weren't lifted into the eager context.
    if context.executing_eagerly():
      raise RuntimeError(
          "tf.Variable not supported when eager execution is enabled. "
          "Please use tf.contrib.eager.Variable instead")
    with ops.name_scope(name, "Variable", [] if init_from_fn else
                        [initial_value]) as name:

      if init_from_fn:
        # Use attr_scope and device(None) to simulate the behavior of
        # colocate_with when the variable we want to colocate with doesn't
        # yet exist.
        true_name = ops._name_from_scope_name(name)  # pylint: disable=protected-access
        attr = attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(
                s=[compat.as_bytes("loc:@%s" % true_name)]))
        # pylint: disable=protected-access
        with ops.get_default_graph()._attr_scope({"_class": attr}):
          with ops.name_scope("Initializer"), ops.device(None):
            self._initial_value = ops.convert_to_tensor(
                initial_value(), name="initial_value", dtype=dtype)
            shape = (self._initial_value.get_shape()
                     if validate_shape else tensor_shape.unknown_shape())
          self._variable = state_ops.variable_op_v2(
              shape,
              self._initial_value.dtype.base_dtype,
              name=name)
        # pylint: enable=protected-access

      # Or get the initial value from a Tensor or Python object.
      else:
        self._initial_value = ops.convert_to_tensor(
            initial_value, name="initial_value", dtype=dtype)
        # A non-callable initial value built inside a loop or conditional is
        # rejected; the error message asks for a lambda instead.
        # pylint: disable=protected-access
        if self._initial_value.op._get_control_flow_context() is not None:
          raise ValueError(
              "Initializer for variable %s is from inside a control-flow "
              "construct, such as a loop or conditional. When creating a "
              "variable inside a loop or conditional, use a lambda as the "
              "initializer." % name)
        # pylint: enable=protected-access
        shape = (self._initial_value.get_shape()
                 if validate_shape else tensor_shape.unknown_shape())
        # In this case, the variable op can't be created until after the
        # initial_value has been converted to a Tensor with a known type.
        self._variable = state_ops.variable_op_v2(
            shape,
            self._initial_value.dtype.base_dtype,
            name=name)

      # With validate_shape, the initial value's static shape must be fully
      # defined. Note this check runs after the variable op was created.
      if validate_shape:
        initial_value_shape = self._initial_value.get_shape()
        if not initial_value_shape.is_fully_defined():
          raise ValueError("initial_value must have a shape specified: %s" %
                           self._initial_value)

      # If 'initial_value' makes use of other variables, make sure we don't
      # have an issue if these other variables aren't initialized first by
      # using their initialized_value() method.
      self._initializer_op = state_ops.assign(
          self._variable,
          self._try_guard_against_uninitialized_dependencies(
              self._initial_value),
          validate_shape=validate_shape).op

      # TODO(vrv): Change this class to not take caching_device, but
      # to take the op to colocate the snapshot with, so we can use
      # colocation rather than devices.
      if caching_device is not None:
        with ops.device(caching_device):
          self._snapshot = array_ops.identity(self._variable, name="read")
      else:
        with ops.colocate_with(self._variable.op):
          self._snapshot = array_ops.identity(self._variable, name="read")

  ops.add_to_collections(collections, self)
  self._caching_device = caching_device
  self._save_slice_info = None
  self._constraint = constraint