def _real_mirrored_creator(devices, *args, **kwargs):
  """Creates one variable per device, seeding later replicas from the first.

  Args:
    devices: Iterable of device strings; one variable is created under each.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`. For every device
      after the first, `name` and `initial_value` are overwritten.

  Returns:
    A list with the created variables, in the same order as `devices`.
  """
  replicas = []
  for replica_id, device_name in enumerate(devices):
    with ops.device(device_name):
      if replica_id > 0:
        # Derive a distinct replica name from the first variable's name.
        base_name = replicas[0].name.split(":")[0]
        # The trailing "/" makes the given name absolute, so the current name
        # scope is ignored for replicas with id > 0.
        kwargs["name"] = "%s/replica_%d/" % (base_name, replica_id)
        # Every replica starts from the same value as the first one.
        if context.executing_eagerly() or ops.inside_function():
          with ops.init_scope():
            kwargs["initial_value"] = array_ops.identity(
                replicas[0].value())
        else:
          # Graph mode: defer the copy; the device is bound as a default
          # argument so that each closure keeps its own device.
          def initial_value_fn(device=device_name):
            with ops.device(device):
              return array_ops.identity(replicas[0].initial_value)
          kwargs["initial_value"] = initial_value_fn
      with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
        created = next_creator(*args, **kwargs)
      assert not isinstance(created, values.TPUMirroredVariable)
      replicas.append(created)
  return replicas
def initial_value_fn(device=d):
  """Returns an identity copy of the first variable's (initial) value.

  In graph mode the copy of `initial_value` is built under `device`; when
  executing eagerly or inside a function the current value is copied instead.
  """
  graph_mode = not (context.executing_eagerly() or ops.inside_function())
  if graph_mode:
    with ops.device(device):
      return array_ops.identity(value_list[0].initial_value)
  return array_ops.identity(value_list[0].value())
def create_file_writer_v2(logdir,
                          max_queue=None,
                          flush_millis=None,
                          filename_suffix=None,
                          name=None):
  """Creates a summary file writer for the given log directory.

  Args:
    logdir: a string specifying the directory in which to write an event file.
    max_queue: the largest number of summaries to keep in a queue; will
      flush once the queue gets bigger than this. Defaults to 10.
    flush_millis: the largest interval between flushes. Defaults to 120,000.
    filename_suffix: optional suffix for the event file name. Defaults to
      `.v2`. A `.<pid>.<uid>` prefix is always prepended to it.
    name: a name for the op that creates the writer.

  Returns:
    A SummaryWriter object.

  Raises:
    ValueError: if `logdir` is None.
  """
  if logdir is None:
    raise ValueError("logdir cannot be None")
  # Captured before entering init_scope(), which leaves any function graph.
  inside_function = ops.inside_function()
  # NOTE(review): block nesting below was reconstructed from a flattened
  # source line -- confirm against the upstream file.
  with ops.name_scope(name, "create_file_writer") as scope, ops.device("cpu:0"):
    # Run init inside an init_scope() to hoist it out of tf.functions.
    with ops.init_scope():
      if context.executing_eagerly():
        _check_create_file_writer_args(
            inside_function,
            logdir=logdir,
            max_queue=max_queue,
            flush_millis=flush_millis,
            filename_suffix=filename_suffix)
      logdir = ops.convert_to_tensor(logdir, dtype=dtypes.string)
      # Fill in defaults for any argument the caller left unset.
      if max_queue is None:
        max_queue = constant_op.constant(10)
      if flush_millis is None:
        flush_millis = constant_op.constant(2 * 60 * 1000)
      if filename_suffix is None:
        filename_suffix = constant_op.constant(".v2")
      # Prepend the PID and a process-local UID to the filename suffix to avoid
      # filename collisions within the machine (the filename already contains
      # the hostname to avoid cross-machine collisions).
      unique_prefix = constant_op.constant(".%s.%s" % (os.getpid(), ops.uid()))
      filename_suffix = unique_prefix + filename_suffix
      # Use a unique shared_name to prevent resource sharing.
      if context.executing_eagerly():
        shared_name = context.shared_name()
      else:
        shared_name = ops.name_from_scope_name(scope)  # pylint: disable=protected-access
      # The init op is passed as a partial; it is not invoked here.
      return ResourceSummaryWriter(
          shared_name=shared_name,
          init_op_fn=functools.partial(
              gen_summary_ops.create_summary_file_writer,
              logdir=logdir,
              max_queue=max_queue,
              flush_millis=flush_millis,
              filename_suffix=filename_suffix),
          name=name,
          v2=True)
def __init__(self, root):
  """Initializes the view, caching saveables only in plain graph mode.

  Args:
    root: Root object forwarded to the parent class constructor.
  """
  plain_graph_mode = not (
      context.executing_eagerly() or ops.inside_function())
  saveables_cache = (
      object_identity.ObjectIdentityWeakKeyDictionary()
      if plain_graph_mode else None)
  super(_AugmentedGraphView, self).__init__(root, saveables_cache)
  # Object -> (name -> dep)
  self._extra_dependencies = object_identity.ObjectIdentityDictionary()
  self._functions = object_identity.ObjectIdentityDictionary()
def initial_value_fn(device=d):
  """Returns an identity copy of the first variable's (initial) value.

  Eagerly or inside a function, the current value is copied. Otherwise the
  copy of `initial_value` is built under `device`.
  """
  if not (context.executing_eagerly() or ops.inside_function()):
    with ops.device(device):
      source = value_list[0].initial_value
      return array_ops.identity(source)
  source = value_list[0].value()
  return array_ops.identity(source)
def _tpu_run(strategy, fn, args, kwargs):
  """Common implementation of TPUStrategy.experimental_run_v2.

  Args:
    strategy: Strategy providing `num_replicas_in_sync`, `scope()` and
      `extended._device_map`.
    fn: Function to run on every replica.
    args: Positional arguments for `fn`; per-replica values are selected with
      `values.select_replica`.
    kwargs: Keyword arguments for `fn`, or None for no keyword arguments.

  Returns:
    The per-replica outputs regrouped with `values.regroup`.

  Raises:
    NotImplementedError: when executing eagerly outside of a function.
  """
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  if kwargs is None:
    kwargs = {}

  # One-element cell holding the structure returned by `fn`; it is used later
  # to re-structure the flattened outputs of `tpu.replicate()`.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  # One [replica_id, args, kwargs] entry per replica.
  replicate_inputs = [
      [constant_op.constant(replica, dtype=dtypes.int32),
       values.select_replica(replica, args),
       values.select_replica(replica, kwargs)]
      for replica in range(strategy.num_replicas_in_sync)
  ]

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  if replicate_inputs:
    first_replica_inputs = replicate_inputs[0]
    flat_shapes = [
        tensor.get_shape() for tensor in nest.flatten(first_replica_inputs)
    ]
    maximum_shapes = nest.pack_sequence_as(first_replica_inputs, flat_shapes)
  else:
    maximum_shapes = None

  with strategy.scope():
    replicate_outputs = tpu.replicate(
        replicated_fn, replicate_inputs, maximum_shapes=maximum_shapes)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [item for item in result[0] if tensor_util.is_tensor(item)]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  structure = result[0]
  replicate_outputs = [
      nest.pack_sequence_as(structure, nest.flatten(per_replica))
      for per_replica in replicate_outputs
  ]

  return values.regroup(
      strategy.extended._device_map,  # pylint: disable=protected-access
      replicate_outputs)
def _create_tpu_mirrored_variable(
    strategy, device_map, logical_device, real_mirrored_creator,
    *args, **kwargs):
  """Creates a `TPUMirroredVariable` from per-device component variables.

  Args:
    strategy: Passed through to `values.TPUMirroredVariable`.
    device_map: Maps `logical_device` to the actual devices and is passed to
      `values.TPUMirroredVariable`.
    logical_device: Logical device whose actual devices get a component.
    real_mirrored_creator: Callable `(devices, *args, **kwargs)` returning the
      list of component variables.
    *args: Forwarded to `real_mirrored_creator`.
    **kwargs: Forwarded to `real_mirrored_creator` after `collections`,
      `aggregation` and `caching_device` have been removed; `collections` is
      replaced by an empty list.

  Returns:
    The created `values.TPUMirroredVariable`.

  Raises:
    ValueError: if `aggregation` is not NONE, SUM, MEAN or ONLY_FIRST_REPLICA.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the TPUMirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  else:
    # Work on a copy: TRAINABLE_VARIABLES may be appended below and the
    # caller's list must not be mutated.
    var_collections = list(var_collections)
  kwargs["collections"] = []

  # TODO(jhseu): Should we have different behavior for different
  # synchronization settings?

  # Get aggregation value
  # TODO(jhseu): Support aggregation in a replica context.
  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
  if aggregation not in [
      vs.VariableAggregation.NONE,
      vs.VariableAggregation.SUM,
      vs.VariableAggregation.MEAN,
      vs.VariableAggregation.ONLY_FIRST_REPLICA,
  ]:
    # `.get` so that a missing "name" cannot turn this into a KeyError.
    raise ValueError(
        "Invalid variable aggregation mode: {} for variable: {}".format(
            aggregation, kwargs.get("name")))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    devices = device_map.logical_to_actual_devices(logical_device)
    value_list = real_mirrored_creator(devices, *args, **kwargs)
    result = values.TPUMirroredVariable(
        strategy, device_map, value_list, aggregation,
        logical_device=logical_device)

  if not (context.executing_eagerly() or ops.inside_function()):
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in value_list:
        l.remove(v)
    g.add_to_collections(var_collections, result)
  return result
def trace_export(name, step=None, profiler_outdir=None):
  """Stops and exports the active trace as a Summary and/or profile file.

  Stops the trace and exports all metadata collected during the trace to the
  default SummaryWriter, if one has been set.

  Args:
    name: A name for the summary to be written.
    step: Explicit `int64`-castable monotonic step value for this summary. If
      omitted, this defaults to `tf.summary.experimental.get_step()`, which
      must not be None.
    profiler_outdir: Output directory for profiler. This is only used when the
      profiler was enabled when the trace was started. In that case, if there
      is a logdir-based default SummaryWriter, this defaults to the same
      directory, but otherwise the argument must be passed.

  Raises:
    ValueError: if a default writer exists, but no step was provided and
      `tf.summary.experimental.get_step()` is None. Also raised when no trace
      is active, or when the profiler is on and no output directory can be
      determined.
  """
  global _current_trace_context

  if ops.inside_function():
    logging.warn("Cannot export trace inside a tf.function.")
    return
  if not context.executing_eagerly():
    logging.warn("Can only export trace while executing eagerly.")
    return

  with _current_trace_context_lock:
    if _current_trace_context is None:
      raise ValueError("Must enable trace before export.")
    trace_graph, trace_profiler = _current_trace_context
    # Fall back to the default writer's logdir when no directory was given.
    if profiler_outdir is None and isinstance(
        _summary_state.writer, ResourceSummaryWriter):
      writer_logdir = _summary_state.writer._metadata.get("logdir")  # pylint: disable=protected-access
      if writer_logdir is not None:
        profiler_outdir = writer_logdir
    if trace_profiler and profiler_outdir is None:
      raise ValueError(
          "Must set profiler_outdir or enable summary writer with logdir.")

  run_meta = context.context().export_run_metadata()

  # Graph-only traces use the graph exporter; everything else the full one.
  export_fn = (
      run_metadata_graphs
      if trace_graph and not trace_profiler else run_metadata)
  export_fn(name, run_meta, step)

  if trace_profiler:
    _profiler.save(profiler_outdir, _profiler.stop())

  # Called after the lock above has been released.
  trace_off()
def _create_tpu_mirrored_variable(
    strategy, device_map, logical_device, real_mirrored_creator,
    *args, **kwargs):
  """Creates a `TPUMirroredVariable` from per-device component variables.

  Args:
    strategy: Passed through to `values.TPUMirroredVariable`.
    device_map: Maps `logical_device` to the actual devices and is passed to
      `values.TPUMirroredVariable`.
    logical_device: Logical device whose actual devices get a component.
    real_mirrored_creator: Callable `(devices, *args, **kwargs)` returning the
      list of component variables.
    *args: Forwarded to `real_mirrored_creator`.
    **kwargs: Forwarded to `real_mirrored_creator` after `collections`,
      `aggregation` and `caching_device` have been removed; `collections` is
      replaced by an empty list.

  Returns:
    The created `values.TPUMirroredVariable`.

  Raises:
    ValueError: if `aggregation` is not NONE, SUM, MEAN or ONLY_FIRST_REPLICA.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the TPUMirroredVariable to those collections instead.
  requested_collections = kwargs.pop("collections", None)
  if requested_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  else:
    # Copy so that appending TRAINABLE_VARIABLES below never mutates the
    # list object owned by the caller.
    var_collections = list(requested_collections)
  kwargs["collections"] = []

  # TODO(jhseu): Should we have different behavior for different
  # synchronization settings?

  # Get aggregation value
  # TODO(jhseu): Support aggregation in a replica context.
  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
  if aggregation not in [
      vs.VariableAggregation.NONE,
      vs.VariableAggregation.SUM,
      vs.VariableAggregation.MEAN,
      vs.VariableAggregation.ONLY_FIRST_REPLICA,
  ]:
    # Use `.get`: indexing would raise KeyError instead of the intended
    # ValueError when no "name" was supplied.
    raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
                     .format(aggregation, kwargs.get("name")))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    devices = device_map.logical_to_actual_devices(logical_device)
    value_list = real_mirrored_creator(devices, *args, **kwargs)
    result = values.TPUMirroredVariable(
        strategy, device_map, value_list, aggregation,
        logical_device=logical_device)

  if not (context.executing_eagerly() or ops.inside_function()):
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in value_list:
        l.remove(v)
    g.add_to_collections(var_collections, result)
  return result
def create_file_writer_v2(logdir,
                          max_queue=None,
                          flush_millis=None,
                          filename_suffix=None,
                          name=None):
  """Creates a summary file writer for the given log directory.

  Args:
    logdir: a string specifying the directory in which to write an event file.
    max_queue: the largest number of summaries to keep in a queue; will
      flush once the queue gets bigger than this. Defaults to 10.
    flush_millis: the largest interval between flushes. Defaults to 120,000.
    filename_suffix: optional suffix for the event file name. Defaults to `.v2`.
    name: a name for the op that creates the writer.

  Returns:
    A SummaryWriter object.

  Raises:
    ValueError: if `logdir` is None.
  """
  if logdir is None:
    raise ValueError("logdir cannot be None")
  # Captured before entering init_scope(), which leaves any function graph.
  inside_function = ops.inside_function()
  # NOTE(review): block nesting below was reconstructed from a flattened
  # source line -- confirm against the upstream file.
  with ops.name_scope(name, "create_file_writer") as scope, ops.device("cpu:0"):
    # Run init inside an init_scope() to hoist it out of tf.functions.
    with ops.init_scope():
      if context.executing_eagerly():
        _check_create_file_writer_args(inside_function,
                                       logdir=logdir,
                                       max_queue=max_queue,
                                       flush_millis=flush_millis,
                                       filename_suffix=filename_suffix)
      logdir = ops.convert_to_tensor(logdir, dtype=dtypes.string)
      # Fill in defaults for any argument the caller left unset.
      if max_queue is None:
        max_queue = constant_op.constant(10)
      if flush_millis is None:
        flush_millis = constant_op.constant(2 * 60 * 1000)
      if filename_suffix is None:
        filename_suffix = constant_op.constant(".v2")
      # Use a unique shared_name to prevent resource sharing.
      if context.executing_eagerly():
        shared_name = context.shared_name()
      else:
        shared_name = ops.name_from_scope_name(scope)  # pylint: disable=protected-access
      # The init op is passed as a partial; it is not invoked here.
      return _ResourceSummaryWriter(
          shared_name=shared_name,
          init_op_fn=functools.partial(
              gen_summary_ops.create_summary_file_writer,
              logdir=logdir,
              max_queue=max_queue,
              flush_millis=flush_millis,
              filename_suffix=filename_suffix),
          name=name)
def generate_dequeue_op(self, tpu_device=0):
  """Generates the device-side Op to dequeue a tuple from the queue.

  Implicitly freezes the queue configuration if it is not already
  frozen, which will raise errors if the shapes and types have not
  been fully specified.

  Args:
    tpu_device: The TPU device ordinal where the infeed instruction should be
      placed. If None, no explicit placement will be performed, and it is up
      to the user to call this API from within a proper TPU device scope.
      The XLA code will fail if the TPU dequeue instruction is not bound to
      any device.

  Returns:
    A list of Outputs corresponding to a shard of infeed dequeued
    into XLA, suitable for use within a replicated block.

  Raises:
    ValueError: if the types or shapes of the tuple elements have not been
      set; or if a dequeue op has already been generated.
  """
  self.freeze()
  # A second dequeue op is only tolerated while inside a function.
  if self._generated_dequeue_op and not ops.inside_function():
    raise ValueError(
        "Can't generate two dequeue Ops from the same queue")
  self._generated_dequeue_op = True

  full_name = "%s/dequeue" % self._name
  sharded_shapes = [
      sharding.get_unpartitioned_shape(sharding.get_sharded_shape(tuple_shape))
      for (tuple_shape, sharding) in zip(self._tuple_shapes,
                                         self._sharding_policies)
  ]

  def _build_dequeue():
    """Creates the dequeue op under whatever device scope is active."""
    return tpu_ops.infeed_dequeue_tuple(
        dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)

  if tpu_device is None:
    dequeue_op = _build_dequeue()
  else:
    with ops.device(tpu_name_util.core(tpu_device)):
      dequeue_op = _build_dequeue()

  if self._number_of_partitions <= 1:
    return dequeue_op

  partitions = [
      sharding.get_unpartitioned_shape([1] * tuple_shape.ndims).as_list()
      for (tuple_shape, sharding) in zip(self._tuple_shapes,
                                         self._sharding_policies)
  ]
  return tag_sharding_attribute_for_dequeued_tensors(dequeue_op, partitions)
def is_in_tf_function():
  """Returns if inside of a tf.function.

  Keras FuncGraphs and v1 `wrap_function` FuncGraphs (graphs whose name
  starts with 'wrapped_function') do not count as a tf.function.
  """
  # Not in any function graph, or in a Keras FuncGraph.
  if not ops.inside_function() or is_in_keras_graph():
    return False
  # Check for a v1 `wrap_function` FuncGraph.
  current_graph = ops.get_default_graph()
  is_wrapped = (getattr(current_graph, 'name', False) and
                current_graph.name.startswith('wrapped_function'))
  return not is_wrapped
def __init__(self, root):
  """Initializes the view, caching saveables only in plain graph mode.

  Args:
    root: Root object forwarded to the parent class constructor.
  """
  plain_graph_mode = not (
      context.executing_eagerly() or ops.inside_function())
  saveables_cache = (
      object_identity.ObjectIdentityWeakKeyDictionary()
      if plain_graph_mode else None)
  super(_AugmentedGraphView, self).__init__(root, saveables_cache)
  # Object -> (name -> dep)
  self._extra_dependencies = object_identity.ObjectIdentityDictionary()
  self._functions = object_identity.ObjectIdentityDictionary()
  # Cache shared between objects in the same object graph. This is passed to
  # each trackable object's `_list_extra_dependencies_for_serialization` and
  # `_list_functions_for_serialization` function.
  self._serialization_cache = object_identity.ObjectIdentityDictionary()
def _disallow_inside_tf_function(method_name):
  """Disallow calling a method inside a `tf.function`.

  Args:
    method_name: Name of the `PreprocessingLayer` method being guarded; it is
      only used to build the error message.

  Raises:
    RuntimeError: if called while inside a function.
  """
  if not ops.inside_function():
    return
  # The second `PreprocessingLayer.{method_name}` was missing its closing
  # backtick, which broke the inline-code markup of the message.
  error_msg = (
      'Detected a call to `PreprocessingLayer.{method_name}` inside a '
      '`tf.function`. `PreprocessingLayer.{method_name}` is a high-level '
      'endpoint that manages its own `tf.function`. Please move the call '
      'to `PreprocessingLayer.{method_name}` outside of all enclosing '
      '`tf.function`s. Note that you can call a `PreprocessingLayer` '
      'directly on `Tensor`s inside a `tf.function` like: `layer(x)`, '
      'or update its state like: `layer.update_state(x)`.').format(
          method_name=method_name)
  raise RuntimeError(error_msg)
def _experimental_distribute_dataset(self, dataset, options):
  """Wraps `dataset` into a distributed dataset for this strategy.

  Args:
    dataset: Dataset to distribute.
    options: Options forwarded to `input_util.get_distributed_dataset`.

  Returns:
    The result of `input_util.get_distributed_dataset`.
  """
  worker_devices = self._input_workers_with_options()
  container_strategy = self._container_strategy()
  replica_count = self._num_replicas_in_sync

  # If this DistributedDataset is created outside ClusterCoordinator, i.e.
  # outside a tf.function, we don't build its underlying datasets immediately
  # until it is passed to ClusterCoordinator.create_per_worker_dataset.
  build_now = ops.inside_function()  # else built by ClusterCoordinator
  return input_util.get_distributed_dataset(
      dataset,
      worker_devices,
      container_strategy,
      num_replicas_in_sync=replica_count,
      options=options,
      build=build_now)
def getNext(self, dataset, requires_initialization=False, shared_name=None):
  """Returns a callable that returns the next element of the dataset.

  Example use:
  ```python
  # In both graph and eager modes
  dataset = ...
  get_next = self.getNext(dataset)
  result = self.evaluate(get_next())
  ```

  Args:
    dataset: A dataset whose elements will be returned.
    requires_initialization: Indicates that when the test is executed in graph
      mode, it should use an initializable iterator to iterate through the
      dataset (e.g. when it contains stateful nodes). Defaults to False.
    shared_name: (Optional.) If non-empty, the returned iterator will be
      shared under the given name across multiple sessions that share the same
      devices (e.g. when using a remote server).

  Returns:
    A callable that returns the next element of `dataset`. Any `TensorArray`
    objects `dataset` outputs are stacked.
  """

  def ta_wrapper(gn):
    """Wraps `gn` so that a `TensorArray` result is returned stacked."""

    def _wrapper():
      element = gn()
      if isinstance(element, tensor_array_ops.TensorArray):
        return element.stack()
      return element

    return _wrapper

  # Create an anonymous iterator if we are in eager-mode or are graph inside
  # of a tf.function.
  if context.executing_eagerly() or ops.inside_function():
    anonymous_iterator = iter(dataset)
    return ta_wrapper(anonymous_iterator._next_internal)  # pylint: disable=protected-access

  if requires_initialization:
    graph_iterator = dataset_ops.make_initializable_iterator(
        dataset, shared_name)
    self.evaluate(graph_iterator.initializer)
  else:
    graph_iterator = dataset_ops.make_one_shot_iterator(dataset)
  # Built once; the returned callable hands back this same object each time.
  get_next = graph_iterator.get_next()
  return ta_wrapper(lambda: get_next)
def trace_export(name, step=None, profiler_outdir=None):
  """Stops and exports the active trace as a Summary and/or profile file.

  Stops the trace and exports all metadata collected during the trace to the
  default SummaryWriter, if one has been set.

  Args:
    name: A name for the summary to be written.
    step: Explicit `int64`-castable monotonic step value for this summary. If
      omitted, this defaults to `tf.summary.experimental.get_step()`, which
      must not be None.
    profiler_outdir: Output directory for profiler. It is required when
      profiler is enabled when trace was started. Otherwise, it is ignored.

  Raises:
    ValueError: if a default writer exists, but no step was provided and
      `tf.summary.experimental.get_step()` is None. Also raised when no trace
      is active, or when the profiler is on and `profiler_outdir` is None.
  """
  # TODO(stephanlee): See if we can remove profiler_outdir and infer it from
  # the SummaryWriter's logdir.
  global _current_trace_context

  if ops.inside_function():
    logging.warn("Cannot export trace inside a tf.function.")
    return
  if not context.executing_eagerly():
    logging.warn("Can only export trace while executing eagerly.")
    return

  with _current_trace_context_lock:
    if _current_trace_context is None:
      raise ValueError(
          "Must enable trace before export through tf.summary.trace_on.")
    trace_graph, trace_profiler = _current_trace_context
    if trace_profiler and profiler_outdir is None:
      raise ValueError("Argument `profiler_outdir` is not specified.")

  run_meta = context.context().export_run_metadata()

  # Graph-only traces use the graph exporter; everything else the full one.
  export_fn = (
      run_metadata_graphs
      if trace_graph and not trace_profiler else run_metadata)
  export_fn(name, run_meta, step)

  if trace_profiler:
    _profiler.save(profiler_outdir, _profiler.stop())

  # Called after the lock above has been released.
  trace_off()
def is_in_tf_function():
  """Returns if inside of a tf.function.

  V1 graph mode, Keras FuncGraphs and v1 `wrap_function` FuncGraphs (graphs
  whose name starts with 'wrapped_function') do not count as a tf.function.
  """
  # Running in V1 graph mode.
  if not ops.executing_eagerly_outside_functions():
    return False
  # Not in any function graph, or in a Keras FuncGraph.
  if not ops.inside_function() or is_in_keras_graph():
    return False
  # Check for a v1 `wrap_function` FuncGraph.
  current_graph = ops.get_default_graph()
  is_wrapped = (getattr(current_graph, 'name', False) and
                current_graph.name.startswith('wrapped_function'))
  return not is_wrapped
def trace_export(name, step=None, profiler_outdir=None):
  """Stops and exports the active trace as a Summary and/or profile file.

  Stops the trace and exports all metadata collected during the trace to the
  default SummaryWriter, if one has been set.

  Args:
    name: A name for the summary to be written.
    step: Explicit `int64`-castable monotonic step value for this summary. If
      omitted, this defaults to `tf.summary.experimental.get_step()`, which
      must not be None.
    profiler_outdir: Output directory for profiler. It is required when
      profiler is enabled when trace was started. Otherwise, it is ignored.

  Raises:
    ValueError: if a default writer exists, but no step was provided and
      `tf.summary.experimental.get_step()` is None. Also raised when no trace
      is active, or when the profiler is on and `profiler_outdir` is None.
  """
  # TODO(stephanlee): See if we can remove profiler_outdir and infer it from
  # the SummaryWriter's logdir.
  global _current_trace_context

  if ops.inside_function():
    logging.warn("Cannot export trace inside a tf.function.")
    return
  if not context.context().executing_eagerly():
    logging.warn("Can only export trace while executing eagerly.")
    return

  with _current_trace_context_lock:
    if _current_trace_context is None:
      raise ValueError("Must enable trace before export.")
    trace_graph, trace_profiler = _current_trace_context
    if trace_profiler and profiler_outdir is None:
      raise ValueError("Required profiler_outdir is not specified")

  run_meta = context.context().export_run_metadata()

  # Graph-only traces use the graph exporter; everything else the full one.
  export_fn = (
      run_metadata_graphs
      if trace_graph and not trace_profiler else run_metadata)
  export_fn(name, run_meta, step)

  if trace_profiler:
    _profiler.save(profiler_outdir, _profiler.stop())

  # Called after the lock above has been released.
  trace_off()
def _save_cached_when_graph_building(self,
                                     file_prefix,
                                     object_graph_tensor,
                                     options,
                                     update_ckpt_state=False):
  """Create or retrieve save ops, overrides parent's private method.

  Args:
    file_prefix: The prefix for saved checkpoint files.
    object_graph_tensor: A `Tensor` to which the current object graph will be
      fed.
    options: `CheckpointOptions` object.
    update_ckpt_state: Optional bool flag. Indicates whether the internal
      checkpoint state needs to be updated. This is used for async checkpoint,
      which DTrackableSaver currently does not support; the flag is not read
      by this method.
  TODO(chienchunh): Implement async checkpoint for DTrackableSaver.

  Returns:
    A two-element tuple with a filename tensor and a feed_dict of tensors to
    feed when running it (if graph building). The feed dict contains the
    current object graph and any Python state to be saved in the
    checkpoint. When executing eagerly only the first argument is meaningful.
  """
  gathered = self._gather_saveables(object_graph_tensor=object_graph_tensor)
  (named_saveable_objects, graph_proto, feed_additions,
   unused_registered_savers) = gathered

  # When executing eagerly, we need to re-create SaveableObjects each time
  # save() is called so they pick up new Tensors passed to their
  # constructors. That means the Saver needs to be copied with a new
  # var_list.
  must_rebuild = (
      self._last_save_object_graph != graph_proto
      or context.executing_eagerly()
      or ops.inside_function())
  if must_rebuild:
    # This is needed to avoid MultiDeviceSaver creating unnecessary MergeV2
    # ops in DTensor. It is an issue when saving TPU Variables on host CPU
    # mesh given our limited expressiveness in API and hard-coded logic in
    # broadcasting -- for a small constant Tensor with no extra information,
    # we place it on the first registered mesh(A.K.A. default mesh).
    mesh_saver = _DSaver(self._mesh, named_saveable_objects)
    save_op = mesh_saver.save(file_prefix, options=options)
    with ops.device("/cpu:0"), ops.control_dependencies([save_op]):
      self._cached_save_operation = array_ops.identity(file_prefix)
    self._last_save_object_graph = graph_proto
  return self._cached_save_operation, feed_additions
def save(self, file_prefix):
  """Saves a training checkpoint and provides basic checkpoint management.

  The saved checkpoint includes variables created by this object and any
  trackable objects it depends on at the time `Checkpoint.save()` is
  called.

  `save` is a basic convenience wrapper around the `write` method,
  sequentially numbering checkpoints using `save_counter` and updating the
  metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
  management, for example garbage collection and custom numbering, may be
  provided by other utilities which also wrap `write`
  (`tf.train.CheckpointManager` for example).

  Args:
    file_prefix: A prefix to use for the checkpoint filenames
      (/path/to/directory/and_a_prefix). Names are generated based on this
      prefix and `Checkpoint.save_counter`.

  Returns:
    The full path to the checkpoint.

  Raises:
    NotImplementedError: if called while building a function graph.
  """
  graph_building = not context.executing_eagerly()
  if graph_building:
    if ops.inside_function():
      raise NotImplementedError(
          "Calling tf.train.Checkpoint.save() from a function is not "
          "supported, as save() modifies saving metadata in ways not "
          "supported by TensorFlow Operations. Consider using "
          "tf.train.Checkpoint.write(), a lower-level API which does not "
          "update metadata. tf.train.latest_checkpoint and related APIs will "
          "not see this checkpoint.")
    session = get_session()
    if self._save_counter is None:
      # When graph building, if this is a new save counter variable then it
      # needs to be initialized before assign_add. This is only an issue if
      # restore() has not been called first.
      session.run(self.save_counter.initializer)
  if not graph_building or self._save_assign_op is None:
    with ops.colocate_with(self.save_counter):
      assign_op = self.save_counter.assign_add(1, read_value=True)
    if graph_building:
      # Cache the op so that repeated graph-mode saves reuse it.
      self._save_assign_op = data_structures.NoDependency(assign_op)
  # Increment the counter and read the new value: in graph mode the cached op
  # has to be run explicitly, eagerly the increment has already executed.
  if graph_building:
    checkpoint_number = session.run(self._save_assign_op)
  else:
    checkpoint_number = assign_op.numpy()
  # Number the checkpoint as promised by the docstring.
  file_path = self.write("%s-%d" % (file_prefix, checkpoint_number))
  checkpoint_management.update_checkpoint_state_internal(
      save_dir=os.path.dirname(file_prefix),
      model_checkpoint_path=file_path,
      all_model_checkpoint_paths=[file_path],
      save_relative_paths=True)
  return file_path
def experimental_run_v2(self, fn, args=(), kwargs=None):
  """See base class.

  Raises:
    NotImplementedError: when executing eagerly outside of a function.
  """
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  if kwargs is None:
    kwargs = {}

  # One-element cell holding the structure returned by `fn`; it is used later
  # to re-structure the flattened outputs of `tpu.replicate()`.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  # One [replica_id, args, kwargs] entry per replica.
  replicate_inputs = [
      [constant_op.constant(replica, dtype=dtypes.int32),
       values.select_replica(replica, args),
       values.select_replica(replica, kwargs)]
      for replica in range(self.num_replicas_in_sync)
  ]

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [item for item in result[0] if tensor_util.is_tensor(item)]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  structure = result[0]
  replicate_outputs = [
      nest.pack_sequence_as(structure, nest.flatten(per_replica))
      for per_replica in replicate_outputs
  ]

  return values.regroup(
      self.extended._device_map,  # pylint: disable=protected-access
      replicate_outputs)
def mark_as_unsaveable():
  """Marks the function as unsaveable if not inside save context.

  Does nothing unless called inside a function and outside a save context.
  """
  if not ops.inside_function() or save_context.in_save_context():
    return
  # The message text is kept exactly as found in the source.
  current_graph = ops.get_default_graph()
  current_graph.mark_as_unsaveable(""" ConcreteFunction that uses distributed variables in certain way cannot be saved. If you're saving with tf.saved_model.save(..., signatures=f.get_concrete_function()) do @tf.function(input_signature=...) def f_with_input_signature(): ... tf.saved_model.save(..., signatures=f_with_input_signature)` instead.""")
def create_iterator(dataset, job_token):
  """Creates an iterator for reading from the tf.data service.

  Args:
    dataset: A `tf.data.Dataset` object.
    job_token: A token generated by `create_job`.

  Returns:
    A dataset iterator.

  Raises:
    RuntimeError: If called outside of a function in graph mode.
  """
  supported = context.executing_eagerly() or ops.inside_function()
  if not supported:
    raise RuntimeError("create_iterator() is only supported inside of "
                       "tf.function or when eager execution is enabled.")
  return iterator_ops.OwnedIterator(dataset, job_token=job_token)
def __init__(self, type_, repr_, stack_frame, error_in_function,
             warn_in_eager):
  """Records the tracked object's details and its initial sated state.

  Args:
    type_: Stored as `_type`.
    repr_: Stored as `_repr`.
    stack_frame: Stored as `_stack_frame`.
    error_in_function: If true and inside a function, the object starts
      unsated and an exit callback calling `_check_sated(raise_error=True)`
      is registered on the default function graph.
    warn_in_eager: If true and executing eagerly, the object starts unsated.
  """
  self._type = type_
  self._repr = repr_
  self._stack_frame = stack_frame
  self._error_in_function = error_in_function

  if context.executing_eagerly():
    # Unsated only when eager warnings were requested.
    self._sated = not warn_in_eager
  elif ops.inside_function():
    # Unsated only when an error on function exit was requested.
    self._sated = not error_in_function
    if error_in_function:
      ops.add_exit_callback_to_default_func_graph(
          lambda: self._check_sated(raise_error=True))
  else:
    # TF1 graph building mode
    self._sated = False
def experimental_run_v2(self, fn, args=(), kwargs=None):
  """See base class.

  Raises:
    NotImplementedError: when executing eagerly outside of a function.
  """
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  if kwargs is None:
    kwargs = {}

  # One-element cell holding the structure returned by `fn`; it is used later
  # to re-structure the flattened outputs of `tpu.replicate()`.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  # One [replica_id, args, kwargs] entry per replica.
  replicate_inputs = [
      [constant_op.constant(replica, dtype=dtypes.int32),
       values.select_replica(replica, args),
       values.select_replica(replica, kwargs)]
      for replica in range(self.num_replicas_in_sync)
  ]

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [item for item in result[0] if tensor_util.is_tensor(item)]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  structure = result[0]
  replicate_outputs = [
      nest.pack_sequence_as(structure, nest.flatten(per_replica))
      for per_replica in replicate_outputs
  ]

  return values.regroup(
      self.extended._device_map,  # pylint: disable=protected-access
      replicate_outputs)
def trace_on(graph=True, profiler=False):  # pylint: disable=redefined-outer-name
  """Starts a trace to record computation graphs and profiling information.

  Must be invoked in eager mode.

  When enabled, the TensorFlow runtime collects information that can later be
  exported and consumed by TensorBoard. The trace is activated across the
  entire TensorFlow runtime and affects all threads of execution.

  To stop the trace and export the collected information, use
  `tf.summary.trace_export`. To stop the trace without exporting, use
  `tf.summary.trace_off`.

  Args:
    graph: If True, enables collection of executed graphs. It includes ones
      from tf.function invocation and ones from the legacy graph mode. The
      default is True.
    profiler: If True, enables the advanced profiler. Enabling profiler
      implicitly enables the graph collection. The profiler may incur a high
      memory overhead. The default is False.
  """
  global _current_trace_context

  if ops.inside_function():
    logging.warn("Cannot enable trace inside a tf.function.")
    return
  if not context.context().executing_eagerly():
    logging.warn("Must enable trace in eager mode.")
    return

  with _current_trace_context_lock:
    if _current_trace_context:
      logging.warn("Trace already enabled")
      return

    eager_context = context.context()
    if profiler:
      eager_context.enable_run_metadata()
      _profiler.start()
    elif graph:
      eager_context.enable_graph_collection()

    _current_trace_context = _TraceContext(graph=graph, profiler=profiler)
def trace_on(graph=True, profiler=False):  # pylint: disable=redefined-outer-name
  """Starts a trace to record computation graphs and profiling information.

  Must be invoked in eager mode, outside of any `tf.function`; otherwise a
  warning is logged and no trace is started. When enabled, the TensorFlow
  runtime collects information that can later be exported and consumed by
  TensorBoard. The trace is activated across the entire TensorFlow runtime
  and affects all threads of execution.

  To stop the trace and export the collected information, use
  `tf.summary.trace_export`. To stop the trace without exporting, use
  `tf.summary.trace_off`.

  If a trace is already active, a warning is logged and the call is a no-op.

  Args:
    graph: If True, enables collection of executed graphs. It includes ones
      from tf.function invocation and ones from the legacy graph mode. The
      default is True.
    profiler: If True, enables the advanced profiler. Enabling profiler
      implicitly enables the graph collection. The profiler may incur a high
      memory overhead. The default is False.
  """
  global _current_trace_context

  if ops.inside_function():
    logging.warn("Cannot enable trace inside a tf.function.")
    return

  eager = context.executing_eagerly()
  if not eager:
    logging.warn("Must enable trace in eager mode.")
    return

  with _current_trace_context_lock:
    if _current_trace_context:
      logging.warn("Trace already enabled")
      return

    # The two modes are mutually exclusive: the profiler path turns on run
    # metadata, the graph-only path turns on graph collection.
    if profiler:
      context.context().enable_run_metadata()
      _profiler.start()
    elif graph:
      context.context().enable_graph_collection()

    _current_trace_context = _TraceContext(graph=graph, profiler=profiler)
def experimental_run(self, fn, input_iterator=None):
  """See base class.

  Invokes `fn` once per replica through `tpu.replicate()`. When an
  `input_iterator` is given, its next element is fetched once and each replica
  receives its own component of it; otherwise `fn` is called with no
  arguments.

  Args:
    fn: The function to run on every replica.
    input_iterator: Optional iterator; `get_next()` is called on it once.

  Returns:
    The result of `values.regroup` over the per-replica outputs.

  Raises:
    NotImplementedError: If executing eagerly outside of a function.
  """
  outside_function = not ops.inside_function()
  if context.executing_eagerly() and outside_function:
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  feeds_input = input_iterator is not None
  inputs = input_iterator.get_next() if feeds_input else []

  # One-slot holder for what `fn` returned during tracing; its structure is
  # used below to re-pack the flat outputs of `tpu.replicate()`.
  fn_output = [None]

  def replicated_fn(replica_id, replica_input):
    """Calls `fn` under a replica context and records what it returned."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      fn_output[0] = fn(replica_input) if feeds_input else fn()
    return fn_output[0]

  # One entry per replica: [replica id, that replica's input].
  replicate_inputs = [
      [
          constant_op.constant(replica, dtype=dtypes.int32),
          values.select_replica(replica, inputs),
      ]
      for replica in range(self.num_replicas_in_sync)
  ]

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  structure = fn_output[0]
  replicate_outputs = [
      nest.pack_sequence_as(structure, nest.flatten(per_replica))
      for per_replica in replicate_outputs
  ]

  device_map = self.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def _distribute_datasets_from_function(self, dataset_fn, options):
  """Builds distributed datasets from `dataset_fn` for this strategy.

  Args:
    dataset_fn: Passed through to
      `input_util.get_distributed_datasets_from_function`.
    options: Used to select the input workers and forwarded as `options`.

  Returns:
    Whatever `input_util.get_distributed_datasets_from_function` returns.
  """
  # There is no synchronization beyond a worker, so each worker has exactly
  # one input pipeline in sync (pipeline 0 of 1).
  input_context = distribute_lib.InputContext(
      num_input_pipelines=1,
      input_pipeline_id=0,
      num_replicas_in_sync=self._num_replicas_in_sync)

  input_workers = self._input_workers_with_options(options)
  strategy = self._container_strategy()

  # Outside a tf.function (i.e. outside ClusterCoordinator) the underlying
  # datasets are not built right away; they are built once the result is
  # passed to ClusterCoordinator.create_per_worker_dataset.
  build_now = ops.inside_function()

  return input_util.get_distributed_datasets_from_function(
      dataset_fn,
      input_workers,
      [input_context],
      strategy,
      options=options,
      build=build_now)
def get_session(op_input_list=()):
    """Returns the session to use: the default one, or a cached fallback.

    The default session is returned when there is one. Otherwise the session
    cached on `keras.backend._SESSION` is returned, being created first if it
    is missing: a distributed session from the default autodist when
    `keras.backend.READY_FOR_AUTODIST` is set, a plain `Session` otherwise.

    Args:
        op_input_list: Not read by this function.

    Returns:
        The session object.

    Raises:
        RuntimeError: If there is no default session and this is called
            inside a graph function.
    """
    session_holder = keras.backend._SESSION

    active_session = ops.get_default_session()
    if active_session is not None:
        return active_session

    if ops.inside_function():
        raise RuntimeError(
            'Cannot get session inside Tensorflow graph function.')

    # Nothing cached yet: create a session and remember it on the holder.
    if getattr(session_holder, 'session', None) is None:
        if getattr(keras.backend, 'READY_FOR_AUTODIST', False):
            default_autodist = autodist.autodist.get_default_autodist()
            new_session = default_autodist.create_distributed_session()
        else:
            new_session = session_module.Session(
                config=get_default_session_config())
        session_holder.session = new_session

    return session_holder.session
def _add_should_use_warning(x, error_in_function=False, warn_in_eager=False): """Wraps object x so that if it is never used, a warning is logged. Args: x: Python object. error_in_function: Python bool. If `True`, a `RuntimeError` is raised if the returned value is never used when created during `tf.function` tracing. warn_in_eager: Python bool. If `True` raise warning if in Eager mode as well as graph mode. Returns: An instance of `TFShouldUseWarningWrapper` which subclasses `type(x)` and is a very shallow wrapper for `x` which logs access into `x`. """ if x is None or (isinstance(x, list) and not x): return x if context.executing_eagerly() and not warn_in_eager: return x if ops.inside_function() and not error_in_function: # We don't currently log warnings in tf.function calls, so just skip it. return x # Extract the current frame for later use by traceback printing. try: raise ValueError() except ValueError: stack_frame = sys.exc_info()[2].tb_frame.f_back tf_should_use_helper = _TFShouldUseHelper( type_=type(x), repr_=repr(x), stack_frame=stack_frame, error_in_function=error_in_function, warn_in_eager=warn_in_eager) return _get_wrapper(x, tf_should_use_helper)
def experimental_run(self, fn, input_iterator=None):
  """See base class.

  Invokes `fn` once per replica through `tpu.replicate()`. When an
  `input_iterator` is given, its next element is fetched once and each replica
  receives its own component of it; otherwise `fn` is called with no
  arguments.

  Args:
    fn: The function to run on every replica.
    input_iterator: Optional iterator; `get_next()` is called on it once.

  Returns:
    The result of `values.regroup` over the per-replica outputs.

  Raises:
    NotImplementedError: If executing eagerly outside of a function.
  """
  outside_function = not ops.inside_function()
  if context.executing_eagerly() and outside_function:
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  feeds_input = input_iterator is not None
  inputs = input_iterator.get_next() if feeds_input else []

  # One-slot holder for what `fn` returned during tracing; its structure is
  # used below to re-pack the flat outputs of `tpu.replicate()`.
  fn_output = [None]

  def replicated_fn(replica_id, replica_input):
    """Calls `fn` under a replica context and records what it returned."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      fn_output[0] = fn(replica_input) if feeds_input else fn()
    return fn_output[0]

  # One entry per replica: [replica id, that replica's input].
  replicate_inputs = [
      [
          constant_op.constant(replica, dtype=dtypes.int32),
          values.select_replica(replica, inputs),
      ]
      for replica in range(self.num_replicas_in_sync)
  ]

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  structure = fn_output[0]
  replicate_outputs = [
      nest.pack_sequence_as(structure, nest.flatten(per_replica))
      for per_replica in replicate_outputs
  ]

  device_map = self.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def __init__(self, initial_value=None, trainable=None, caching_device=None,
             name=None, dtype=None, constraint=None, add_initializers_to=None,
             lifted_initializer_graph=None, synchronization=None,
             aggregation=None, shape=None, **unused_kwargs):
  """Creates a variable.

  When called outside of a function definition this simply delegates to
  `ResourceVariable.__init__` and returns. Inside a function, the initial
  value is converted to a tensor and the initialization is either lifted to
  the outer graph (graph mode) or guarded by a `cond` on whether the variable
  is already initialized.

  Args:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. The initial value must
      have a shape specified unless `validate_shape` is set to False. Can also
      be a callable with no argument that returns the initial value when
      called. (Note that initializer functions from init_ops.py must first be
      bound to a shape before being used here.)
    trainable: If `True`, GradientTapes automatically watch uses of this
      Variable.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device. If not `None`, caches on another device. Typical use is to
      cache on the device where the Ops using the Variable reside, to
      deduplicate copying through `Switch` and other conditional statements.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    dtype: If set, initial_value will be converted to the given type.
      If None, either the datatype will be kept (if initial_value is
      a Tensor) or float32 will be used (if it is a Python object convertible
      to a Tensor).
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value
      (which must have the same shape). Constraints are not safe to
      use when doing asynchronous distributed training.
    add_initializers_to: if not None and not in legacy graph mode, the
      initializer tensor will be added to this map in addition to adding the
      assignment to the function.
    lifted_initializer_graph: FuncGraph to try to lift initializers to. Not
      referenced in the body of this constructor.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize. If `synchronization` is set to `ON_READ`,
      `trainable` must not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    shape: (optional) The shape of this variable. If None, the shape of
      `initial_value` will be used. When setting this argument to
      `tf.TensorShape(None)` (representing an unspecified shape), the variable
      can be assigned with values of different shapes.
    **unused_kwargs: Forwarded unchanged to the base-class constructor on the
      inside-function path.

  Raises:
    ValueError: If `initial_value` is not specified (inside a function), or
      if `constraint` is given but is not callable.
  """
  if not ops.inside_function():
    # If we've been init_scope()d out of the function definition nothing to do
    # here; we can't really do the capturing or conditional logic.
    resource_variable_ops.ResourceVariable.__init__(
        self, initial_value=initial_value, trainable=trainable,
        caching_device=caching_device, name=name, dtype=dtype,
        constraint=constraint)
    return
  if initial_value is None:
    raise ValueError("initial_value must be specified.")
  init_from_fn = callable(initial_value)

  if constraint is not None and not callable(constraint):
    raise ValueError("The `constraint` argument must be a callable.")

  # Unwrap a checkpoint-provided initial value, remembering the restore UID
  # it came from.
  if isinstance(initial_value, trackable.CheckpointInitialValue):
    self._maybe_initialize_trackable()
    self._update_uid = initial_value.checkpoint_position.restore_uid
    initial_value = initial_value.wrapped_value

  with ops.name_scope(name, "Variable", []
                      if init_from_fn else [initial_value]) as name:
    # Call the initializer (if it is a callable) and convert the result to a
    # tensor, with any device setting cleared.
    with ops.name_scope("Initializer"), ops.device(None):
      initial_value = ops.convert_to_tensor(
          initial_value() if init_from_fn else initial_value,
          name="initial_value", dtype=dtype)
    assert initial_value is not None

    # Don't use `shape or initial_value.shape` since TensorShape has
    # overridden `__bool__`.
    if shape is None:
      shape = initial_value.shape

  # Use the constructor for UninitializedVariable to start.
  super(UnliftedInitializerVariable, self).__init__(
      trainable=trainable,
      caching_device=caching_device,
      name=name,
      shape=shape,
      dtype=initial_value.dtype,
      constraint=constraint,
      synchronization=synchronization,
      aggregation=aggregation,
      extra_handle_data=initial_value,
      **unused_kwargs)

  if self._in_graph_mode:
    # `init_scope` exposes the graph outside the function being built; the
    # second lookup, without it, is the function's own graph.
    with ops.init_scope():
      outer_graph = ops.get_default_graph()
    func_graph = ops.get_default_graph()
    # Ops behind the function's inputs and internal captures are passed as
    # `disallowed_placeholders` when lifting the initializer.
    function_placeholders = (
        func_graph.inputs + func_graph.internal_captures)
    placeholder_ops = set(
        [tensor.op for tensor in function_placeholders])
    lifted_initializer = lift_to_graph.lift_to_graph(
        [initial_value], outer_graph,
        disallowed_placeholders=placeholder_ops)[initial_value]
    with ops.init_scope():
      self._initial_value = lifted_initializer
      with ops.name_scope("IsInitialized"):
        self._is_initialized_op = (
            resource_variable_ops.var_is_initialized_op(self._handle))
      # Always true here given the assert above (unless asserts are stripped).
      if initial_value is not None:
        with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
          self._initializer_op = resource_variable_ops.assign_variable_op(
              self._handle, lifted_initializer, name=n)
  else:
    # Record the initializer for the caller, if a map was supplied.
    if add_initializers_to is not None:
      add_initializers_to[self] = initial_value

    def assign_fn():
      """Assigns the initial value to the variable's handle."""
      with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
        resource_variable_ops.assign_variable_op(
            self._handle,
            initial_value,
            name=n)
      # Returning values to keep tf.cond happy.
      return ops.convert_to_tensor(1)

    def not_assign_fn():
      """No-op branch taken when the variable is already initialized."""
      return ops.convert_to_tensor(0)

    # Note: this cond is always guaranteed to run because we're inside a
    # defun which will insert automatic control dependencies.
    control_flow_ops.cond(
        resource_variable_ops.var_is_initialized_op(self._handle),
        not_assign_fn, assign_fn)
def __init__(self, initial_value=None, trainable=None, caching_device=None,
             name=None, dtype=None, constraint=None, add_initializers_to=None,
             lifted_initializer_graph=None, synchronization=None,
             aggregation=None, **unused_kwargs):
  """Creates a variable.

  When called outside of a function definition this simply delegates to
  `ResourceVariable.__init__` and returns. Inside a function, the initial
  value is converted to a tensor and the initialization is either lifted to
  the outer graph (graph mode) or guarded by a `cond` on whether the variable
  is already initialized.

  Args:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. The initial value must
      have a shape specified unless `validate_shape` is set to False. Can also
      be a callable with no argument that returns the initial value when
      called. (Note that initializer functions from init_ops.py must first be
      bound to a shape before being used here.)
    trainable: If `True`, GradientTapes automatically watch uses of this
      Variable.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device. If not `None`, caches on another device. Typical use is to
      cache on the device where the Ops using the Variable reside, to
      deduplicate copying through `Switch` and other conditional statements.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    dtype: If set, initial_value will be converted to the given type.
      If None, either the datatype will be kept (if initial_value is
      a Tensor) or float32 will be used (if it is a Python object convertible
      to a Tensor).
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value
      (which must have the same shape). Constraints are not safe to
      use when doing asynchronous distributed training.
    add_initializers_to: if not None and not in legacy graph mode, the
      initializer tensor will be added to this map in addition to adding the
      assignment to the function.
    lifted_initializer_graph: FuncGraph to try to lift initializers to. Not
      referenced in the body of this constructor.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize. If `synchronization` is set to `ON_READ`,
      `trainable` must not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    **unused_kwargs: Forwarded unchanged to the base-class constructor on the
      inside-function path.

  Raises:
    ValueError: If `initial_value` is not specified (inside a function), or
      if `constraint` is given but is not callable.
  """
  if not ops.inside_function():
    # If we've been init_scope()d out of the function definition nothing to do
    # here; we can't really do the capturing or conditional logic.
    resource_variable_ops.ResourceVariable.__init__(
        self, initial_value=initial_value, trainable=trainable,
        caching_device=caching_device, name=name, dtype=dtype,
        constraint=constraint)
    return
  if initial_value is None:
    raise ValueError("initial_value must be specified.")
  init_from_fn = callable(initial_value)

  if constraint is not None and not callable(constraint):
    raise ValueError("The `constraint` argument must be a callable.")

  # Unwrap a checkpoint-provided initial value, remembering the restore UID
  # it came from.
  if isinstance(initial_value, trackable.CheckpointInitialValue):
    self._maybe_initialize_trackable()
    self._update_uid = initial_value.checkpoint_position.restore_uid
    initial_value = initial_value.wrapped_value

  with ops.name_scope(name, "Variable", []
                      if init_from_fn else [initial_value]) as name:
    # Call the initializer (if it is a callable) and convert the result to a
    # tensor, with any device setting cleared.
    with ops.name_scope("Initializer"), ops.device(None):
      initial_value = ops.convert_to_tensor(
          initial_value() if init_from_fn else initial_value,
          name="initial_value", dtype=dtype)
    assert initial_value is not None

  # Use the constructor for UninitializedVariable to start.
  super(UnliftedInitializerVariable, self).__init__(
      trainable=trainable,
      caching_device=caching_device,
      name=name,
      shape=initial_value.shape,
      dtype=initial_value.dtype,
      constraint=constraint,
      synchronization=synchronization,
      aggregation=aggregation,
      extra_handle_data=initial_value,
      **unused_kwargs)

  if self._in_graph_mode:
    # `init_scope` exposes the graph outside the function being built; the
    # second lookup, without it, is the function's own graph.
    with ops.init_scope():
      outer_graph = ops.get_default_graph()
    func_graph = ops.get_default_graph()
    # Ops behind the function's inputs and internal captures are passed as
    # `disallowed_placeholders` when lifting the initializer.
    function_placeholders = (
        func_graph.inputs + func_graph.internal_captures)
    placeholder_ops = set(
        [tensor.op for tensor in function_placeholders])
    lifted_initializer = lift_to_graph.lift_to_graph(
        [initial_value], outer_graph,
        disallowed_placeholders=placeholder_ops)[initial_value]
    with ops.init_scope():
      self._initial_value = lifted_initializer
      with ops.name_scope("IsInitialized"):
        self._is_initialized_op = (
            resource_variable_ops.var_is_initialized_op(self._handle))
      # Always true here given the assert above (unless asserts are stripped).
      if initial_value is not None:
        with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
          self._initializer_op = resource_variable_ops.assign_variable_op(
              self._handle, lifted_initializer, name=n)
  else:
    # Record the initializer for the caller, if a map was supplied.
    if add_initializers_to is not None:
      add_initializers_to[self] = initial_value

    def assign_fn():
      """Assigns the initial value to the variable's handle."""
      with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
        resource_variable_ops.assign_variable_op(
            self._handle,
            initial_value,
            name=n)
      # Returning values to keep tf.cond happy.
      return ops.convert_to_tensor(1)

    def not_assign_fn():
      """No-op branch taken when the variable is already initialized."""
      return ops.convert_to_tensor(0)

    # Note: this cond is always guaranteed to run because we're inside a
    # defun which will insert automatic control dependencies.
    control_flow_ops.cond(
        resource_variable_ops.var_is_initialized_op(self._handle),
        not_assign_fn, assign_fn)
def _tpu_run(strategy, fn, args, kwargs):
  """Common implementation of TPUStrategy.experimental_run_v2.

  Invokes `fn` once per replica through `tpu.replicate()`, passing the static
  shapes of the first replica's inputs as `maximum_shapes`, and regroups the
  per-replica results with `values.regroup`.

  Args:
    strategy: The strategy whose replicas, scope and device map are used.
    fn: The function to run on every replica.
    args: Positional arguments for `fn`; `values.select_replica` picks the
      per-replica component.
    kwargs: Keyword arguments for `fn`, handled like `args`. `None` is
      treated as an empty dict.

  Returns:
    The result of `values.regroup` over the per-replica outputs.

  Raises:
    NotImplementedError: If executing eagerly outside of a function.
  """
  outside_function = not ops.inside_function()
  if context.executing_eagerly() and outside_function:
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  kwargs = {} if kwargs is None else kwargs

  # One-slot holder for whatever `fn` returned while being traced; its
  # structure is used below to re-pack the flat outputs of `tpu.replicate()`.
  fn_output = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Calls `fn` under a replica context and records what it returned."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      fn_output[0] = fn(*replica_args, **replica_kwargs)
    return fn_output[0]

  def _shape_of(value):
    """Returns the shape of a tensor, or of a non-tensor via `np.shape`."""
    if tensor_util.is_tensor(value):
      return value.get_shape()
    return tensor_shape.TensorShape(np.shape(value))

  # One entry per replica: [replica id, positional inputs, keyword inputs].
  replicate_inputs = [
      [
          constant_op.constant(replica, dtype=dtypes.int32),
          values.select_replica(replica, args),
          values.select_replica(replica, kwargs),
      ]
      for replica in range(strategy.num_replicas_in_sync)
  ]

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  maximum_shapes = None
  if replicate_inputs:
    first_replica_inputs = replicate_inputs[0]
    flat_shapes = [
        _shape_of(value) for value in nest.flatten(first_replica_inputs)
    ]
    maximum_shapes = nest.pack_sequence_as(first_replica_inputs, flat_shapes)

  with strategy.scope():
    replicate_outputs = tpu.replicate(
        replicated_fn, replicate_inputs, maximum_shapes=maximum_shapes)

  # Remove all no ops that may have been added during 'tpu.replicate()'.
  if isinstance(fn_output[0], list):
    fn_output[0] = [
        item for item in fn_output[0] if tensor_util.is_tensor(item)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  structure = fn_output[0]
  replicate_outputs = [
      nest.pack_sequence_as(structure, nest.flatten(per_replica))
      for per_replica in replicate_outputs
  ]

  device_map = strategy.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
def save(obj, export_dir, signatures=None):
  # pylint: disable=line-too-long
  """Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

  Example usage:

  ```python
  class Adder(tf.train.Checkpoint):

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
      return x + x + 1.

  to_export = Adder()
  tf.saved_model.save(to_export, '/tmp/adder')
  ```

  The resulting SavedModel is then servable with an input named "x", its value
  having any shape and dtype float32.

  The optional `signatures` argument controls which methods in `obj` will be
  available to programs which consume `SavedModel`s, for example serving
  APIs. Python functions may be decorated with
  `@tf.function(input_signature=...)` and passed as signatures directly, or
  lazily with a call to `get_concrete_function` on the method decorated with
  `@tf.function`.

  If the `signatures` argument is omitted, `obj` will be searched for
  `@tf.function`-decorated methods. If exactly one `@tf.function` is found, that
  method will be used as the default signature for the SavedModel. This behavior
  is expected to change in the future, when a corresponding
  `tf.saved_model.load` symbol is added. At that point signatures will be
  completely optional, and any `@tf.function` attached to `obj` or its
  dependencies will be exported for use with `load`.

  When invoking a signature in an exported SavedModel, `Tensor` arguments are
  identified by name. These names will come from the Python function's argument
  names by default. They may be overridden by specifying a `name=...` argument
  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
  multiple `Tensor`s are passed through a single argument to the Python
  function.

  The outputs of functions used as `signatures` must either be flat lists, in
  which case outputs will be numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys will be used to name outputs.

  Signatures are available in objects returned by `tf.saved_model.load` as a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.

  Since `tf.keras.Model` objects are also Trackable, this function can be
  used to export Keras models. For example, exporting with a signature
  specified:

  ```python
  class Model(tf.keras.Model):

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def serve(self, serialized):
      ...

  m = Model()
  tf.saved_model.save(m, '/tmp/saved_model/')
  ```

  Exporting from a function without a fixed signature:

  ```python
  class Model(tf.keras.Model):

    @tf.function
    def call(self, x):
      ...

  m = Model()
  tf.saved_model.save(
      m, '/tmp/saved_model/',
      signatures=m.call.get_concrete_function(
          tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
  ```

  `tf.keras.Model` instances constructed from inputs and outputs already have a
  signature and so do not require a `@tf.function` decorator or a `signatures`
  argument. If neither are specified, the model's forward pass is exported.

  ```python
  x = input_layer.Input((4,), name="x")
  y = core.Dense(5, name="out")(x)
  model = training.Model(x, y)
  tf.saved_model.save(model, '/tmp/saved_model/')
  # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  # with shape [None, 5]
  ```

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory. Currently variables are the only stateful objects
  supported by `tf.saved_model.save`, but others (e.g. tables) will be supported
  in the future.

  `tf.function` does not hard-code device annotations from outside the function
  body, instead using the calling context's device. This means for example that
  exporting a model which runs on a GPU and serving it on a CPU will generally
  work, with some exceptions. `tf.device` annotations inside the body of the
  function will be hard-coded in the exported model; this type of annotation is
  discouraged. Device-specific operations, e.g. with "cuDNN" in the name or with
  device-specific layouts, may cause issues. Currently a `DistributionStrategy`
  is another exception: active distribution strategies will cause device
  placements to be hard-coded in a function. Exporting a single-device
  computation and importing under a `DistributionStrategy` is not currently
  supported, but may be in the future.

  SavedModels exported with `tf.saved_model.save` [strip default-valued
  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the consumer
  of a SavedModel is running an older TensorFlow version than the
  producer. There are however other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  The current implementation of `tf.saved_model.save` targets serving use-cases,
  but omits information which will be necessary for the planned future
  implementation of `tf.saved_model.load`. Exported models using the current
  `save` implementation, and other existing SavedModels, will not be compatible
  with `tf.saved_model.load` when it is implemented. Further, `save` will in the
  future attempt to export `@tf.function`-decorated methods which it does not
  currently inspect, so some objects which are exportable today will raise
  exceptions on export in the future (e.g. due to complex/non-serializable
  default arguments). Such backwards-incompatible API changes are expected only
  prior to the TensorFlow 2.0 release.

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`, in which case `f` will be used to
      generate a signature for the SavedModel under the default serving
      signature key. `signatures` may also be a dictionary, in which case it
      maps from signature keys to either `tf.function` instances with input
      signatures or concrete functions. The keys of such a dictionary may be
      arbitrary strings, but will typically be from the
      `tf.saved_model.signature_constants` module.

  Raises:
    AssertionError: If called from inside a traced `@tf.function`.
    ValueError: If `obj` is not trackable.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
  add new save operations to the default graph each iteration.

  May not be called from within a function body.
  @end_compatibility
  """
  # Exporting is rejected outright while a function is being traced.
  if ops.inside_function():
    raise AssertionError(
        "tf.saved_model.save is not supported inside a traced "
        "@tf.function. Move the call to the outer eagerly-executed "
        "context.")
  # pylint: enable=line-too-long
  if not isinstance(obj, base.Trackable):
    raise ValueError(
        "Expected a Trackable object for export, got {}.".format(obj))

  checkpoint_graph_view = _AugmentedGraphView(obj)
  # With no explicit signatures, look for a function to export on the object.
  if signatures is None:
    signatures = signature_serialization.find_function_to_export(
        checkpoint_graph_view)

  signatures = signature_serialization.canonicalize_signatures(signatures)
  signature_serialization.validate_saveable_view(checkpoint_graph_view)
  # Attach the signature map to the root object under the reserved signature
  # attribute name.
  signature_map = signature_serialization.create_signature_map(signatures)
  checkpoint_graph_view.add_object(
      parent_node=checkpoint_graph_view.root,
      name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
      subgraph_root=signature_map)

  # Use _SaveableView to provide a frozen listing of properties and functions.
  # Note we run this twice since, while constructing the view the first time
  # there can be side effects of creating variables.
  _ = _SaveableView(checkpoint_graph_view)
  saveable_view = _SaveableView(checkpoint_graph_view)

  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  saved_model = saved_model_pb2.SavedModel()
  meta_graph_def = saved_model.meta_graphs.add()
  object_saver = util.TrackableSaver(checkpoint_graph_view)
  asset_info, exported_graph = _fill_meta_graph_def(
      meta_graph_def, saveable_view, signatures)
  saved_model.saved_model_schema_version = (
      constants.SAVED_MODEL_SCHEMA_VERSION)
  # So far we've just been generating protocol buffers with no I/O. Now we write
  # the checkpoint, copy assets into the assets directory, and write out the
  # SavedModel proto itself.
  utils_impl.get_or_create_variables_dir(export_dir)
  object_saver.save(utils_impl.get_variables_path(export_dir))
  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                              export_dir)
  path = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
  # The object graph is serialized into the meta graph before the SavedModel
  # proto is written out.
  object_graph_proto = _serialize_object_graph(
      saveable_view, asset_info.asset_index)
  meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
  file_io.write_string_to_file(path, saved_model.SerializeToString())
  # Clean reference cycles so repeated export()s don't make work for the garbage
  # collector. Before this point we need to keep references to captured
  # constants in the saved graph.
  ops.dismantle_graph(exported_graph)
def __init__(
    self,  # pylint: disable=super-init-not-called
    initial_value=None,
    trainable=None,
    caching_device=None,
    name=None,
    dtype=None,
    constraint=None,
    add_initializers_to=None,
    lifted_initializer_graph=None,
    synchronization=None,
    aggregation=None,
    **unused_kwargs):
  """Creates a variable.

  When called outside of a function-building context
  (`ops.inside_function()` is false) this simply delegates to
  `ResourceVariable.__init__` and returns. Inside a function, the handle is
  created under `ops.init_scope()`; then either the initializer is lifted out
  of the function graph (when the init scope is a graph) or an assignment
  guarded by an "is initialized" check is added to the function (when the
  init scope is eager).

  Args:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. The initial value must have
      a shape specified unless `validate_shape` is set to False. Can also be a
      callable with no argument that returns the initial value when called.
      (Note that initializer functions from init_ops.py must first be bound to
      a shape before being used here.)
    trainable: If `True`, GradientTapes automatically watch uses of this
      Variable.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device. If not `None`, caches on another device. Typical use is to
      cache on the device where the Ops using the Variable reside, to
      deduplicate copying through `Switch` and other conditional statements.
      Only forwarded on the outside-of-function fallback path.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    dtype: If set, initial_value will be converted to the given type. If None,
      either the datatype will be kept (if initial_value is a Tensor) or
      float32 will be used (if it is a Python object convertible to a Tensor).
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value (which must have
      the same shape). Constraints are not safe to use when doing asynchronous
      distributed training.
    add_initializers_to: if not None and not in legacy graph mode, the
      initializer tensor will be added to this map in addition to adding the
      assignment to the function.
    lifted_initializer_graph: FuncGraph to try to lift initializers to. Not
      read anywhere in this method's body.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses when to
      synchronize. If `synchronization` is set to `ON_READ`, `trainable` must
      not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    **unused_kwargs: Accepted and ignored.

  Raises:
    ValueError: If the initial value is not specified, or if `constraint` is
      given but is not callable.
  """
  if not ops.inside_function():
    # If we've been init_scope()d out of the function definition nothing to do
    # here; we can't really do the capturing or conditional logic.
    # Note that `synchronization` and `aggregation` are not forwarded here.
    resource_variable_ops.ResourceVariable.__init__(
        self, initial_value=initial_value, trainable=trainable,
        caching_device=caching_device, name=name, dtype=dtype,
        constraint=constraint)
    return
  # Graph vs. eager is decided by the context outside the function being
  # built, hence the init_scope.
  with ops.init_scope():
    self._in_graph_mode = not context.executing_eagerly()
  if initial_value is None:
    raise ValueError("initial_value must be specified.")
  init_from_fn = callable(initial_value)

  if constraint is not None and not callable(constraint):
    raise ValueError("The `constraint` argument must be a callable.")

  # Unwrap checkpoint-provided initial values, remembering the restore UID.
  if isinstance(initial_value, trackable.CheckpointInitialValue):
    self._maybe_initialize_trackable()
    self._update_uid = initial_value.checkpoint_position.restore_uid
    initial_value = initial_value.wrapped_value

  synchronization, aggregation, trainable = (
      variables.validate_synchronization_aggregation_trainable(
          synchronization, aggregation, trainable, name))
  self._trainable = trainable
  self._synchronization = synchronization
  self._aggregation = aggregation
  self._save_slice_info = None
  self._initial_value = None
  self._initializer_op = None
  self._is_initialized_op = None
  self._graph_element = None
  self._cached_value = None
  # Store the graph key so optimizers know how to only retrieve variables from
  # this graph. Guaranteed to be the same as the eager graph_key.
  self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  with ops.name_scope(name, "Variable", []
                      if init_from_fn else [initial_value]) as name:
    # pylint: disable=protected-access
    # Names are computed under init_scope, i.e. outside the function graph.
    with ops.init_scope():
      handle_name = ops.name_from_scope_name(name)
      unique_id = "%s_%d" % (handle_name, ops.uid())
      shared_name = context.shared_name(unique_id)
    # The initial value itself is built in the current (function) graph, with
    # any device setting cleared.
    with ops.name_scope("Initializer"), ops.device(None):
      initial_value = ops.convert_to_tensor(
          initial_value() if init_from_fn else initial_value,
          name="initial_value", dtype=dtype)
    with ops.init_scope():
      self._handle = resource_variable_ops.eager_safe_variable_handle(
          initial_value=initial_value,
          shared_name=shared_name,
          name=name,
          graph_mode=self._in_graph_mode)
    self._shape = initial_value.shape
    self._unique_id = unique_id
    self._handle_name = handle_name + ":0"
    self._dtype = initial_value.dtype.base_dtype
    self._constraint = constraint
    assert initial_value is not None
    if self._in_graph_mode:
      # The init scope is a graph: copy the initializer out of the function
      # graph into that outer graph, refusing to lift through the function's
      # own inputs and captures.
      with ops.init_scope():
        outer_graph = ops.get_default_graph()
      func_graph = ops.get_default_graph()
      function_placeholders = (func_graph.inputs +
                               func_graph.internal_captures)
      placeholder_ops = set(
          [tensor.op for tensor in function_placeholders])
      lifted_initializer = lift_to_graph.lift_to_graph(
          [initial_value], outer_graph,
          disallowed_placeholders=placeholder_ops)[initial_value]
      with ops.init_scope():
        self._initial_value = lifted_initializer
        with ops.name_scope("IsInitialized"):
          self._is_initialized_op = (
              resource_variable_ops.var_is_initialized_op(
                  self._handle))
        # Always true here because of the assert above.
        if initial_value is not None:
          with ops.name_scope("Assign") as n, ops.colocate_with(
              self._handle):
            self._initializer_op = resource_variable_ops.assign_variable_op(
                self._handle, lifted_initializer, name=n)
        with ops.name_scope("Read"), ops.colocate_with(
            self._handle):
          # Manually assign reads to the handle's device to avoid log
          # messages.
          with ops.device(self._handle.device):
            value = self._read_variable_op()
          self._graph_element = value
        ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
    else:
      # The init scope is eager: record the initializer if requested and add
      # a conditional assignment to the function being built.
      if add_initializers_to is not None:
        add_initializers_to[self] = initial_value

      def assign_fn():
        with ops.name_scope("Assign") as n, ops.colocate_with(
            self._handle):
          resource_variable_ops.assign_variable_op(self._handle,
                                                   initial_value,
                                                   name=n)
          # Returning values to keep tf.cond happy.
        return ops.convert_to_tensor(1)

      def not_assign_fn():
        return ops.convert_to_tensor(0)

      # Note: this cond is always guaranteed to run because we're inside a
      # defun which will insert automatic control dependencies.
      control_flow_ops.cond(
          resource_variable_ops.var_is_initialized_op(self._handle),
          not_assign_fn, assign_fn)

  # After the handle has been created, set up a way to clean it up when
  # executing eagerly. We'll hold the only reference to the deleter, so that
  # when this object is garbage collected the deleter will be too. This
  # means ResourceVariables can be part of reference cycles without those
  # cycles being uncollectable.
  if not self._in_graph_mode:
    self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._handle, handle_device=self._handle.device)
  self._cached_shape_as_list = None
def save(obj, export_dir, signatures=None, options=None):
  # pylint: disable=line-too-long
  """Writes the Trackable object `obj` out in [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

  Example usage:

  ```python
  class Adder(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
      return x + x + 1.

  to_export = Adder()
  tf.saved_model.save(to_export, '/tmp/adder')
  ```

  The SavedModel produced above can be served with an input named "x" of any
  shape and dtype float32.

  The optional `signatures` argument selects which methods of `obj` are made
  available to programs that consume `SavedModel`s, such as serving APIs.
  Python functions may be decorated with
  `@tf.function(input_signature=...)` and passed as signatures directly, or
  lazily with a call to `get_concrete_function` on the method decorated with
  `@tf.function`.

  When `signatures` is omitted, `obj` is searched for `@tf.function`-decorated
  methods. If exactly one `@tf.function` is found, that method becomes the
  default signature of the SavedModel. This behavior is expected to change in
  the future, when a corresponding `tf.saved_model.load` symbol is added. At
  that point signatures will be completely optional, and any `@tf.function`
  attached to `obj` or its dependencies will be exported for use with `load`.

  When a signature of an exported SavedModel is invoked, `Tensor` arguments
  are identified by name. By default the names are those of the Python
  function's arguments. They may be overridden by specifying a `name=...`
  argument in the corresponding `tf.TensorSpec` object. Explicit naming is
  required if multiple `Tensor`s are passed through a single argument to the
  Python function.

  Functions used as `signatures` must return either a flat list, in which
  case the outputs are numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys name the outputs.

  Objects returned by `tf.saved_model.load` expose signatures through a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.

  Because `tf.keras.Model` objects are also Trackable, this function can be
  used to export Keras models. For example, exporting with a signature
  specified:

  ```python
  class Model(tf.keras.Model):

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def serve(self, serialized):
      ...

  m = Model()
  tf.saved_model.save(m, '/tmp/saved_model/')
  ```

  Exporting from a function without a fixed signature:

  ```python
  class Model(tf.keras.Model):

    @tf.function
    def call(self, x):
      ...

  m = Model()
  tf.saved_model.save(
      m, '/tmp/saved_model/',
      signatures=m.call.get_concrete_function(
          tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
  ```

  `tf.keras.Model` instances constructed from inputs and outputs already have
  a signature and so do not require a `@tf.function` decorator or a
  `signatures` argument. If neither are specified, the model's forward pass is
  exported.

  ```python
  x = input_layer.Input((4,), name="x")
  y = core.Dense(5, name="out")(x)
  model = training.Model(x, y)
  tf.saved_model.save(model, '/tmp/saved_model/')
  # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  # with shape [None, 5]
  ```

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory. Currently variables are the only stateful objects
  supported by `tf.saved_model.save`, but others (e.g. tables) will be
  supported in the future.

  `tf.function` does not hard-code device annotations from outside the
  function body, instead using the calling context's device. This means for
  example that exporting a model which runs on a GPU and serving it on a CPU
  will generally work, with some exceptions. `tf.device` annotations inside
  the body of the function will be hard-coded in the exported model; this type
  of annotation is discouraged. Device-specific operations, e.g. with "cuDNN"
  in the name or with device-specific layouts, may cause issues. Currently a
  `DistributionStrategy` is another exception: active distribution strategies
  will cause device placements to be hard-coded in a function. Exporting a
  single-device computation and importing under a `DistributionStrategy` is
  not currently supported, but may be in the future.

  SavedModels exported with `tf.saved_model.save` [strip default-valued
  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the
  consumer of a SavedModel is running an older TensorFlow version than the
  producer. There are however other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`, in which case `f` will be used to
      generate a signature for the SavedModel under the default serving
      signature key. `signatures` may also be a dictionary, in which case it
      maps from signature keys to either `tf.function` instances with input
      signatures or concrete functions. The keys of such a dictionary may be
      arbitrary strings, but will typically be from the
      `tf.saved_model.signature_constants` module.
    options: Optional, `tf.saved_model.SaveOptions` object that specifies
      options for saving.

  Raises:
    ValueError: If `obj` is not trackable.
    AssertionError: If called while a function is being traced.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.compat.v1.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
  add new save operations to the default graph each iteration.

  May not be called from within a function body.
  @end_compatibility
  """
  # Guard clauses: refuse to run while tracing and reject untracked objects.
  if ops.inside_function():
    raise AssertionError(
        "tf.saved_model.save is not supported inside a traced "
        "@tf.function. Move the call to the outer eagerly-executed "
        "context.")
  # pylint: enable=line-too-long
  if not isinstance(obj, base.Trackable):
    raise ValueError(
        "Expected a Trackable object for export, got {}.".format(obj))
  if not options:
    options = save_options.SaveOptions()

  # Resolve the signatures and attach them to the object graph under the
  # reserved signatures attribute.
  graph_view = _AugmentedGraphView(obj)
  if signatures is None:
    signatures = signature_serialization.find_function_to_export(graph_view)
  signatures = signature_serialization.canonicalize_signatures(signatures)
  signature_serialization.validate_saveable_view(graph_view)
  exported_signatures = signature_serialization.create_signature_map(
      signatures)
  graph_view.add_object(
      parent_node=graph_view.root,
      name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
      subgraph_root=exported_signatures)

  # _SaveableView gives a frozen listing of properties and functions. It is
  # built twice on purpose: building it the first time can create variables
  # as a side effect, so only the second view is kept.
  _SaveableView(graph_view)
  saveable_view = _SaveableView(graph_view)

  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  saved_model_proto = saved_model_pb2.SavedModel()
  meta_graph = saved_model_proto.meta_graphs.add()
  checkpoint_saver = util.TrackableSaver(graph_view)
  asset_info, exported_graph = _fill_meta_graph_def(
      meta_graph, saveable_view, signatures, options.namespace_whitelist)
  saved_model_proto.saved_model_schema_version = (
      constants.SAVED_MODEL_SCHEMA_VERSION)

  # Everything above only built protocol buffers. From here on we do I/O:
  # checkpoint first, then assets, then (optionally) debug info, and the
  # SavedModel proto last.
  utils_impl.get_or_create_variables_dir(export_dir)
  checkpoint_saver.save(utils_impl.get_variables_path(export_dir))
  builder_impl.copy_assets_to_destination_dir(
      asset_info.asset_filename_map, export_dir)
  saved_model_pb_path = os.path.join(
      compat.as_str(export_dir),
      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
  serialized_object_graph = _serialize_object_graph(
      saveable_view, asset_info.asset_index)
  meta_graph.object_graph_def.CopyFrom(serialized_object_graph)

  # Debug info is written only on request.
  if options.save_debug_info:
    graph_debug_info = _export_debug_info(exported_graph)
    debug_info_path = os.path.join(
        utils_impl.get_or_create_debug_dir(export_dir),
        constants.DEBUG_INFO_FILENAME_PB)
    file_io.atomic_write_string_to_file(
        debug_info_path,
        graph_debug_info.SerializeToString(deterministic=True))

  # This must stay the final file operation: users treat the presence of
  # saved_model_dir/saved_model.pb as the sign that the SavedModel has been
  # written completely.
  serialized_saved_model = saved_model_proto.SerializeToString(
      deterministic=True)
  file_io.atomic_write_string_to_file(
      saved_model_pb_path, serialized_saved_model)

  # Break reference cycles so that repeated export()s don't create work for
  # the garbage collector. References to captured constants in the saved graph
  # had to be kept alive until this point.
  ops.dismantle_graph(exported_graph)
def create_file_writer_v2(logdir,
                          max_queue=None,
                          flush_millis=None,
                          filename_suffix=None,
                          name=None,
                          experimental_trackable=False):
  """Builds a summary file writer that writes to `logdir`.

  Args:
    logdir: a string specifying the directory in which to write an event file.
    max_queue: the largest number of summaries to keep in a queue; will flush
      once the queue gets bigger than this. Defaults to 10.
    flush_millis: the largest interval between flushes. Defaults to 120,000.
    filename_suffix: optional suffix for the event file name. Defaults to
      `.v2`.
    name: a name for the op that creates the writer.
    experimental_trackable: a boolean that controls whether the returned
      writer will be a `TrackableResource`, which makes it compatible with
      SavedModel when used as a `tf.Module` property.

  Returns:
    A SummaryWriter object.

  Raises:
    ValueError: if `logdir` is None.
  """
  if logdir is None:
    raise ValueError("Argument `logdir` cannot be None")
  # Captured before entering init_scope(), which hoists out of tf.functions.
  called_inside_function = ops.inside_function()
  with ops.name_scope(name, "create_file_writer") as scope, ops.device("cpu:0"):
    # Initialization happens under init_scope() so that it is lifted out of
    # any tf.function being traced.
    with ops.init_scope():
      if context.executing_eagerly():
        _check_create_file_writer_args(
            called_inside_function,
            logdir=logdir,
            max_queue=max_queue,
            flush_millis=flush_millis,
            filename_suffix=filename_suffix)
      logdir = ops.convert_to_tensor(logdir, dtype=dtypes.string)
      # Fill in defaults for anything the caller left unspecified.
      max_queue = constant_op.constant(10) if max_queue is None else max_queue
      flush_millis = (
          constant_op.constant(2 * 60 * 1000)
          if flush_millis is None else flush_millis)
      filename_suffix = (
          constant_op.constant(".v2")
          if filename_suffix is None else filename_suffix)

      def create_fn():
        # Eager mode gets a unique shared_name so resources are never shared;
        # otherwise the name is fixed so SavedModel TF 1.x loading works.
        shared_name = (
            context.anonymous_name() if context.executing_eagerly() else
            ops.name_from_scope_name(scope))  # pylint: disable=protected-access
        return gen_summary_ops.summary_writer(
            shared_name=shared_name, name=name)

      init_op_fn = functools.partial(
          gen_summary_ops.create_summary_file_writer,
          logdir=logdir,
          max_queue=max_queue,
          flush_millis=flush_millis,
          filename_suffix=filename_suffix)

      # Both writer classes take the same constructor arguments.
      writer_cls = (
          _TrackableResourceSummaryWriter
          if experimental_trackable else _ResourceSummaryWriter)
      return writer_cls(create_fn=create_fn, init_op_fn=init_op_fn)
def __init__(self,  # pylint: disable=super-init-not-called
             initial_value=None,
             trainable=None,
             caching_device=None,
             name=None,
             dtype=None,
             constraint=None,
             add_initializers_to=None,
             lifted_initializer_graph=None,
             **unused_kwargs):
  """Creates a variable.

  When called outside of a function-building context
  (`ops.inside_function()` is false) this simply delegates to
  `ResourceVariable.__init__` and returns. Inside a function, the handle is
  created under `ops.init_scope()`; then either the initializer is lifted out
  of the function graph (when the init scope is a graph) or an assignment
  guarded by an "is initialized" check is added to the function (when the
  init scope is eager).

  Args:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. The initial value must have
      a shape specified unless `validate_shape` is set to False. Can also be a
      callable with no argument that returns the initial value when called.
      (Note that initializer functions from init_ops.py must first be bound to
      a shape before being used here.)
    trainable: If `True`, GradientTapes automatically watch uses of this
      Variable. `None` is treated as `True` on the inside-function path.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device. If not `None`, caches on another device. Typical use is to
      cache on the device where the Ops using the Variable reside, to
      deduplicate copying through `Switch` and other conditional statements.
      Only forwarded on the outside-of-function fallback path.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    dtype: If set, initial_value will be converted to the given type. If None,
      either the datatype will be kept (if initial_value is a Tensor) or
      float32 will be used (if it is a Python object convertible to a Tensor).
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value (which must have
      the same shape). Constraints are not safe to use when doing asynchronous
      distributed training.
    add_initializers_to: if not None and not in legacy graph mode, the
      initializer tensor will be added to this map in addition to adding the
      assignment to the function.
    lifted_initializer_graph: FuncGraph to try to lift initializers to. Not
      read anywhere in this method's body.
    **unused_kwargs: Accepted and ignored.

  Raises:
    ValueError: If the initial value is not specified, or if `constraint` is
      given but is not callable.
  """
  if not ops.inside_function():
    # If we've been init_scope()d out of the function definition nothing to do
    # here; we can't really do the capturing or conditional logic.
    resource_variable_ops.ResourceVariable.__init__(
        self, initial_value=initial_value, trainable=trainable,
        caching_device=caching_device, name=name, dtype=dtype,
        constraint=constraint)
    return
  # Graph vs. eager is decided by the context outside the function being
  # built, hence the init_scope.
  with ops.init_scope():
    self._in_graph_mode = not context.executing_eagerly()
  if initial_value is None:
    raise ValueError("initial_value must be specified.")
  init_from_fn = callable(initial_value)

  if constraint is not None and not callable(constraint):
    raise ValueError("The `constraint` argument must be a callable.")

  # Unwrap checkpoint-provided initial values, remembering the restore UID.
  if isinstance(initial_value, trackable.CheckpointInitialValue):
    self._maybe_initialize_trackable()
    self._update_uid = initial_value.checkpoint_position.restore_uid
    initial_value = initial_value.wrapped_value

  if trainable is None:
    trainable = True
  self._trainable = trainable
  self._save_slice_info = None
  self._initial_value = None
  self._initializer_op = None
  self._is_initialized_op = None
  self._graph_element = None
  self._cached_value = None
  # Store the graph key so optimizers know how to only retrieve variables from
  # this graph. Guaranteed to be the same as the eager graph_key.
  self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  with ops.name_scope(name, "Variable", []
                      if init_from_fn else [initial_value]) as name:
    # pylint: disable=protected-access
    # Names are computed under init_scope, i.e. outside the function graph.
    with ops.init_scope():
      handle_name = ops._name_from_scope_name(name)
      unique_id = "%s_%d" % (handle_name, ops.uid())
      shared_name = context.shared_name(unique_id)
    # The initial value itself is built in the current (function) graph, with
    # any device setting cleared.
    with ops.name_scope("Initializer"), ops.device(None):
      initial_value = ops.convert_to_tensor(
          initial_value() if init_from_fn else initial_value,
          name="initial_value", dtype=dtype)
    with ops.init_scope():
      self._handle = resource_variable_ops.eager_safe_variable_handle(
          initial_value=initial_value,
          shared_name=shared_name,
          name=name,
          graph_mode=self._in_graph_mode)
    self._shape = initial_value.shape
    self._unique_id = unique_id
    self._handle_name = handle_name + ":0"
    self._dtype = initial_value.dtype.base_dtype
    self._constraint = constraint
    assert initial_value is not None
    if self._in_graph_mode:
      # The init scope is a graph: copy the initializer out of the function
      # graph into that outer graph, refusing to lift through the function's
      # own inputs and captures.
      with ops.init_scope():
        outer_graph = ops.get_default_graph()
      func_graph = ops.get_default_graph()
      function_placeholders = (
          func_graph.inputs + func_graph.internal_captures)
      placeholder_ops = set(
          [tensor.op for tensor in function_placeholders])
      lifted_initializer = lift_to_graph.lift_to_graph(
          [initial_value], outer_graph,
          disallowed_placeholders=placeholder_ops)[initial_value]
      with ops.init_scope():
        self._initial_value = lifted_initializer
        with ops.name_scope("IsInitialized"):
          self._is_initialized_op = (
              resource_variable_ops.var_is_initialized_op(self._handle))
        # Always true here because of the assert above.
        if initial_value is not None:
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            self._initializer_op = resource_variable_ops.assign_variable_op(
                self._handle, lifted_initializer, name=n)
        with ops.name_scope("Read"), ops.colocate_with(self._handle):
          # Manually assign reads to the handle's device to avoid log
          # messages.
          with ops.device(self._handle.device):
            value = self._read_variable_op()
          self._graph_element = value
        ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
    else:
      # The init scope is eager: record the initializer if requested and add
      # a conditional assignment to the function being built.
      if add_initializers_to is not None:
        add_initializers_to[self] = initial_value

      def assign_fn():
        with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
          resource_variable_ops.assign_variable_op(
              self._handle,
              initial_value,
              name=n)
          # Returning values to keep tf.cond happy.
        return ops.convert_to_tensor(1)

      def not_assign_fn():
        return ops.convert_to_tensor(0)

      # Note: this cond is always guaranteed to run because we're inside a
      # defun which will insert automatic control dependencies.
      control_flow_ops.cond(
          resource_variable_ops.var_is_initialized_op(self._handle),
          not_assign_fn, assign_fn)

  # After the handle has been created, set up a way to clean it up when
  # executing eagerly. We'll hold the only reference to the deleter, so that
  # when this object is garbage collected the deleter will be too. This
  # means ResourceVariables can be part of reference cycles without those
  # cycles being uncollectable.
  if not self._in_graph_mode:
    self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._handle, handle_device=self._handle.device)
  self._cached_shape_as_list = None
def __init__(self,
             dataset=None,
             devices=None,
             max_buffer_size=1,
             prefetch_buffer_size=1,
             source_device="/cpu:0",
             components=None,
             element_spec=None):
  """Constructs an owned MultiDeviceIterator object.

  The iterator is built either from a `dataset` or from previously created
  `components` plus their `element_spec`; exactly one of the two forms must
  be supplied.

  Args:
    dataset: The input dataset to be iterated over.
    devices: The list of devices to fetch data to.
    max_buffer_size: Maximum size of the host side per device buffer to keep.
      In order to prevent deadlocks, if the prefetch_buffer_size is greater
      than the max_buffer_size, we set the max_buffer_size to
      prefetch_buffer_size.
    prefetch_buffer_size: if > 0, then we setup a buffer on each device to
      prefetch into.
    source_device: The host device to place the `dataset` on.
    components: Tensor components to construct the MultiDeviceIterator from.
      Read as `[multi_device_iterator_resource, deleter, *device_iterators]`.
    element_spec: A (nested) structure of `tf.TypeSpec` objects that
      represents the type specification of elements of the iterator.

  Raises:
    RuntimeError: If called when eager execution is disabled and not inside
      of a `tf.function`.
    ValueError: If `devices` is None, if neither `dataset` nor both of
      `components` and `element_spec` are given, or if `dataset` is given
      together with `components` or `element_spec`.
  """
  if not context.executing_eagerly() and not ops.inside_function():
    raise RuntimeError(
        "OwnedMultiDeviceIterator is only supported inside of "
        "tf.function or when eager execution is enabled.")
  if devices is None:
    raise ValueError("`devices` must be provided")
  # The parentheses matter: without them only the first literal is assigned
  # and the second one is a stand-alone expression statement, which truncates
  # the message shown to the user.
  error_message = ("Either `dataset` or both `components` and "
                   "`element_spec` need to be provided.")

  if dataset is None:
    # Rebuild the iterator from already-created resources.
    if components is None or element_spec is None:
      raise ValueError(error_message)
    self._element_spec = element_spec
    self._devices = devices
    self._source_device = source_device
    self._multi_device_iterator_resource = components[0]
    self._deleter = components[1]
    self._device_iterators = components[2:]
    iterator_handles = [
        it._iterator_resource  # pylint: disable=protected-access
        for it in self._device_iterators
    ]
  else:
    if components is not None or element_spec is not None:
      raise ValueError(error_message)
    options = dataset_ops.Options()
    options.experimental_distribute.num_devices = len(devices)
    dataset = dataset.with_options(options)
    dataset = dataset._apply_options()  # pylint: disable=protected-access
    self._element_spec = dataset.element_spec
    experimental_slack = dataset.options().experimental_slack
    self._devices = devices
    self._source_device = source_device
    source_device_tensor = ops.convert_to_tensor(self._source_device)

    # A prefetch buffer larger than the host buffer could deadlock, so grow
    # the host buffer to match.
    if prefetch_buffer_size > max_buffer_size:
      max_buffer_size = prefetch_buffer_size

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      self._multi_device_iterator_resource, self._deleter = (
          gen_dataset_ops.anonymous_multi_device_iterator(
              devices=self._devices, **dataset._flat_structure))  # pylint: disable=protected-access

      # The incarnation ID is used to ensure consistency between the
      # per-device iterators and the multi-device iterator.
      incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=max_buffer_size)

    prototype_device_datasets = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
                                 incarnation_id, source_device_tensor,
                                 dataset.element_spec)
        prototype_device_datasets.append(ds)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    iterator_handles = []
    for device, prototype in zip(self._devices, prototype_device_datasets):
      with ops.device(device):
        ds = _create_device_dataset(prototype, incarnation_id,
                                    prefetch_buffer_size, experimental_slack)
        iterator = iter(ds)
        self._device_iterators.append(iterator)
        iterator_handles.append(iterator._iterator_resource)  # pylint: disable=protected-access

  # Both branches produce `iterator_handles`; the deleter receives the
  # multi-device resource together with every per-device iterator handle.
  self._resource_deleter = MultiDeviceIteratorResourceDeleter(
      multi_device_iterator=self._multi_device_iterator_resource,
      iterators=iterator_handles,
      device=self._source_device,
      deleter=self._deleter)