def testEagerMultiLearnerCheckpointCompatibility(self):
  self.assertTrue(tf.executing_eagerly())
  cfg = model_registry.GetParams('test.LinearModelParams', 'Train')
  mdl = cfg.Instantiate()
  with py_utils.GradientTape(persistent=True):
    mdl.ConstructFPropBPropGraph()

  eager_v1_logdir = os.path.join(self.get_temp_dir(), 'eager_v1')
  eager_v2_logdir = os.path.join(self.get_temp_dir(), 'eager_v2')
  checkpointer.EagerCheckpointerV1(eager_v1_logdir, mdl).Save(gsteps=0)
  checkpointer.EagerCheckpointerV2(eager_v2_logdir, mdl).Save(gsteps=0)
  eager_v1_keys = _GetCheckpointKeys(
      os.path.join(eager_v1_logdir, 'ckpt_V1', 'ckpt-00000000'))
  eager_v2_keys = _GetCheckpointKeys(
      os.path.join(eager_v2_logdir, 'ckpt_V2', 'ckpt-0'))
  # Expecting two more variables in V2 checkpoints:
  # _CHECKPOINTABLE_OBJECT_GRAPH
  # save_counter
  self.assertEqual(len(eager_v1_keys) + 2, len(eager_v2_keys))  # pylint:disable=g-generic-assert

  py_utils.SetEagerMode(False)
  self.assertFalse(tf.executing_eagerly())
  graph_logdir = os.path.join(self.get_temp_dir(), 'graph')
  os.mkdir(graph_logdir)
  with self.session(graph=tf.Graph()) as sess:
    mdl = cfg.Instantiate()
    for lrn in mdl.GetTask().learners:
      lrn.optimizer.params.clear_variable_scope = False
    mdl.ConstructFPropBPropGraph()
    sess.run(tf.global_variables_initializer())
    checkpointer.Checkpointer(graph_logdir, mdl).Save(sess)
  graph_keys = _GetCheckpointKeys(os.path.join(graph_logdir, 'ckpt'))
  self.assertEqual(eager_v1_keys, graph_keys)
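# `_GetCheckpointKeys` is not shown in this snippet. A minimal sketch of what
# the test assumes, built on the stock TF checkpoint reader (an assumption,
# not necessarily the test's exact helper):
def _GetCheckpointKeys(ckpt_prefix):
  """Returns the set of variable keys stored under `ckpt_prefix`."""
  reader = tf.train.load_checkpoint(ckpt_prefix)
  return set(reader.get_variable_to_shape_map().keys())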
def GetTensorName(tensor, name_eager=None, i_eager=None):
  """Returns the tensor name.

  Useful for compatibility with eager mode.

  Args:
    tensor: tensor
    name_eager: additional string to append in eager mode
    i_eager: additional index to append in eager mode

  Returns:
    tensor.name in session mode, or a concatenation of name_eager and i_eager
    in eager mode.
  """
  if not tf.executing_eagerly():
    tensor_name = tensor.name
  else:
    # Compare against None so that a valid index of 0 is not silently dropped.
    if name_eager is not None and i_eager is not None:
      tensor_name = f'[eager]_{name_eager}_{i_eager}'
    elif name_eager is not None:
      tensor_name = f'[eager]_{name_eager}'
    elif i_eager is not None:
      tensor_name = f'[eager]_{i_eager}'
    else:
      tensor_name = '[eager]'
  return tensor_name
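# Illustrative use of GetTensorName (a sketch, not part of the library code):
# in graph mode the real tensor name is returned; in eager mode tensors have
# no stable graph names, so a label is built from the extra arguments.
x = tf.constant([1.0], name='logits')
print(GetTensorName(x, name_eager='logits', i_eager=2))
# Graph/session mode -> 'logits:0'; eager mode -> '[eager]_logits_2'.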
def testEagerEMACheckpointCompatibility(self):
  self.assertTrue(tf.executing_eagerly())
  cfg = model_registry.GetParams('test.LinearModelParams', 'Train')
  # Use a non-zero learning rate so that the weights are updated.
  cfg.task.train.learner[0].learning_rate = 0.1
  cfg.task.train.learner[1].learning_rate = 0.1

  eager_v1_logdir = os.path.join(self.get_temp_dir(), 'eager_v1')
  eager_v2_logdir = os.path.join(self.get_temp_dir(), 'eager_v2')
  mdl = cfg.Instantiate()

  @tf.function
  def _Update():
    with py_utils.GradientTape(persistent=True):
      mdl.ConstructFPropBPropGraph()

  # Step 1
  _Update()
  # Save V1 checkpoints at step 1.
  ckpt_v1 = checkpointer.EagerCheckpointerV1(eager_v1_logdir, mdl)
  ckpt_v1.Save(gsteps=1)
  ema = mdl.ema
  model_to_ema_map = _GetModelEMAVariablePairs(mdl, ema)
  model_to_ema_map_snapshot_step1 = {
      k: v.value() for k, v in model_to_ema_map.items()
  }

  # Step 2
  _Update()
  # Save V2 checkpoints at step 2.
  ckpt_v2 = checkpointer.EagerCheckpointerV2(eager_v2_logdir, mdl)
  ckpt_v2.Save(gsteps=2)
  model_to_ema_map = _GetModelEMAVariablePairs(mdl, ema)
  model_to_ema_map_snapshot_step2 = {
      k: v.value() for k, v in model_to_ema_map.items()
  }

  with cluster_factory.SetEval(True):
    # Restores variables to the values saved in `eager_v1_logdir`.
    ckpt_v1.Restore()
    # Verify that the EMA variables from the V1 checkpoint at step 1
    # successfully overwrite the model variables.
    for v in mdl.variables:
      if v.ref() in model_to_ema_map_snapshot_step1:
        self.assertAllEqual(v, model_to_ema_map_snapshot_step1[v.ref()])

  with cluster_factory.SetEval(True):
    # Restores variables to the values saved in `eager_v2_logdir`.
    ckpt_v2.Restore()
    # Verify that the EMA variables from the V2 checkpoint at step 2
    # successfully overwrite the model variables.
    for v in mdl.variables:
      if v.ref() in model_to_ema_map_snapshot_step2:
        self.assertAllEqual(v, model_to_ema_map_snapshot_step2[v.ref()])
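# `_GetModelEMAVariablePairs` is not shown above. A minimal sketch, assuming
# `ema` is a `tf.train.ExponentialMovingAverage` and that every model variable
# with a shadow copy should be paired with it:
def _GetModelEMAVariablePairs(mdl, ema):
  """Returns a dict of model-variable ref -> its EMA shadow variable."""
  pairs = {}
  for v in mdl.variables:
    shadow = ema.average(v)  # None if no EMA is kept for this variable.
    if shadow is not None:
      pairs[v.ref()] = shadow
  return pairs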
def Reset(self, sess):
  if self._dataset:
    if tf.executing_eagerly():
      self._iterator = {key: iter(ds) for key, ds in self._dataset.items()}
    else:
      sess.run([it.initializer for it in self._iterator.values()])
def _InitIterator(self):
  if self.host_id in self._dataset:
    return

  with py_utils.GlobalStepContext(None):
    # Hide global_step tensor from being captured by dataset function.
    ds = self.GetDataset()
  ds.options().experimental_deterministic = False
  self._dataset[self.host_id] = ds
  if tf.executing_eagerly():
    it = iter(ds)
  else:
    it = tf.data.make_initializable_iterator(ds)
  self._iterator[self.host_id] = it
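# Sketch of how the per-host iterator is typically consumed in each mode (the
# method name is illustrative, not part of the snippet above): eager iterators
# follow the Python iterator protocol, while graph-mode initializable
# iterators expose get_next(), which returns tensors to be fed to sess.run.
def _GetNextBatch(self):
  it = self._iterator[self.host_id]
  if tf.executing_eagerly():
    return next(it)
  return it.get_next()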
def AddEvalMetric(self, name, value, weight, raise_if_already_added=True):
  """Adds a metric to the eval metrics.

  Args:
    name: A python string. The name of the metric.
    value: A scalar Tensor.
    weight: A scalar Tensor.
    raise_if_already_added: If the metric already exists, raise a ValueError.

  Raises:
    ValueError: if `name` is already defined.
  """
  if name in self._eval_metrics and not tf.executing_eagerly():
    if raise_if_already_added:
      raise ValueError('Metric %s has already been defined.' % name)
  self._eval_metrics[name] = (value, weight)
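# Illustrative usage (names are hypothetical). Note that the duplicate-name
# check is skipped in eager mode because the method body re-executes on every
# step, so re-adding the same metric name there is expected rather than a bug.
# self.AddEvalMetric('loss', mean_loss, num_examples)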
def PlotSequenceFeatures(plots, name, **kwargs):
  """Plots a stack of sequence features.

  Args:
    plots: A list of (tensor, seq_len) tuples, as returned by
      PrepareSequenceForPlot().
    name: A string for the caption of the plot.
    **kwargs: Keyword arguments passed to AddSubplot().
  """
  if not _ShouldAddSummary():
    return

  with plot.MatplotlibFigureSummary(
      name, figsize=(8, len(plots) * 3.5)) as fig:
    for i, (tensor, seq_len) in enumerate(plots):
      if not tf.executing_eagerly():
        tensor_name = tensor.name
      else:
        tensor_name = f'[eager]_{name}_{i}'
      fig.AddSubplot([tensor, seq_len],
                     TrimPaddingAndPlotSequence,
                     title=tensor_name,
                     **kwargs)
def __init__(self, train_cfg, ps_params_dict, *args, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  Args:
    train_cfg: SingleTaskModelParams or MultiTaskModelParams.
    ps_params_dict: A dict of top-level task name -> ProgramSchedule params.
      If train_cfg is a SingleTaskModelParams, we expect only one entry.
    *args: List args to pass through to BaseRunner.
    **kwargs: keyword args to pass through to BaseRunner.
  """
  if py_utils.IsEagerMode():
    assert tf.executing_eagerly()
    tf.logging.info(f'FLAGS.tf_master: {FLAGS.tf_master}')

    # Connect to the TPU runtime.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
    tf.config.experimental_connect_to_cluster(resolver)

  super().__init__(train_cfg, *args, **kwargs)

  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  self.task_scheduler = None
  self._checkpoint_dir = os.path.join(self._logdir, 'train')

  self._variable_renaming_rules = []

  self._ml_perf = None

  # If this is a multi-task model, grab the params for the TaskScheduler.
  if issubclass(train_cfg.cls, base_model.SingleTaskModel):
    tf.logging.info('single_task_model')
    assert len(ps_params_dict) == 1
    self._model_task_name = list(ps_params_dict.keys())[0]
    self._single_task_mode = True
  elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
    tf.logging.info('multi_task_model')
    if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
      self._variable_renaming_rules = train_cfg.variable_renaming_rules

    if train_cfg.task_schedule is None:
      task_schedule_params = task_scheduler.ConstantScheduler.Params()
      task_schedule_params.task_probs = sorted(
          list(train_cfg.task_probs.IterParams()))
    else:
      task_schedule_params = train_cfg.task_schedule
    self.task_scheduler = task_schedule_params.Instantiate()
    self._single_task_mode = False
  else:
    tf.logging.fatal(
        'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
        train_cfg.cls)

  tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

  self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                   'trainer_params.txt')
  self._WriteToLog(
      text_format.MessageToString(train_cfg.ToProto(), as_utf8=True),
      self._checkpoint_dir, 'trainer_params.pbtxt')
  if self._ml_perf is not None:
    self._ml_perf_log = True
    mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
  else:
    self._ml_perf_log = False

  train_cfg = self.params

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit(job=None):
    """Wait until the model is ready."""
    try:
      if py_utils.IsEagerMode():
        topology = tf.tpu.experimental.initialize_tpu_system(resolver)
      else:
        # tpu.initialize_system() is called with None as embedding_config, as
        # embedding_config is not available yet. Later in _Loop, it is called
        # with the correct embedding_config. Since it cannot be called twice
        # in the same graph with different embedding_config, we use a
        # dummy_graph here.
        dummy_graph = tf.Graph()
        with dummy_graph.as_default():
          tpu_initialize_system_op = tf.tpu.initialize_system(
              embedding_config=None, job=job)
        with self._GetSession(graph=dummy_graph) as sess:
          topology = sess.run(tpu_initialize_system_op)

      if train_cfg.train.tpu_computation_shape is None:
        computation_shape = py_utils.ComputationShape(num_devices_per_split,
                                                      topology)
      else:
        computation_shape = train_cfg.train.tpu_computation_shape
      assert num_devices_per_split == np.prod(computation_shape)

      if train_cfg.train.tpu_device_order_mode is None:
        self.device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=computation_shape,
            num_replicas=data_parallelism)
      else:
        self.device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=computation_shape,
            num_replicas=data_parallelism,
            device_order_mode=train_cfg.train.tpu_device_order_mode)
      py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
      tf.logging.info('device_assignment.core_assignment: %s',
                      str(self.device_assignment.core_assignment))
      tf.logging.info('device_assignment.topology.device_coordinates: %s',
                      str(self.device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  if self._ml_perf_log:
    mlp_log.mlperf_print(key='init_start', value=None)
  if len(self._cluster.all_worker_names) > 1:
    for worker in self._cluster.all_worker_names:
      _WaitTillInit(worker)
  else:
    _WaitTillInit(None)

  shared_model = self._MaybeConstructSharedModel(train_cfg)

  self._program_schedule_dict = {}
  self._programs = []
  self._ckpt_programs = []

  self._checkpoint_to_load = None
  with self._cluster:
    # Create the ExponentialMovingAverage singleton shared by all programs,
    # if applicable.
    ema = py_utils.CreateEMAForModel(train_cfg, self._global_step_var)
    for task_string, program_schedule_params in ps_params_dict.items():
      program_schedule_params.logdir = self._logdir
      program_schedule_params.num_splits_per_client = data_parallelism
      program_schedule_params.task_name = task_string
      # If the model was created above, we'll inject it here as a
      # shared_model.
      ps = program_schedule_params.Instantiate(
          shared_model=shared_model,
          trial=self._trial,
          ema=ema,
          tf_master=self._tf_master)
      self._program_schedule_dict[task_string] = ps
      tf.logging.info('program_schedule_params: %s',
                      program_schedule_params.ToText())
      self._programs += ps.Programs()

      if ps.train_program:
        self._ckpt_programs.append(ps.train_program)
      else:
        self._ckpt_programs += ps.Programs()

      if program_schedule_params.ml_perf.benchmark_name is not None:
        self._ml_perf = program_schedule_params.ml_perf
      if ('checkpoint_to_load' in program_schedule_params and
          program_schedule_params.checkpoint_to_load):
        if (self._checkpoint_to_load and
            (self._checkpoint_to_load !=
             program_schedule_params.checkpoint_to_load)):
          raise ValueError(f'Multiple values found for checkpoint_to_load: '
                           f'{self._checkpoint_to_load}, '
                           f'{program_schedule_params.checkpoint_to_load}.')
        self._checkpoint_to_load = program_schedule_params.checkpoint_to_load

  tf.logging.info('num_programs: %d', len(self._programs))

  # When running in a vizier trainer, the executor reports infeasible runs
  # in case of errors. The programs report metrics and normal completions.
  for program in self._programs:
    if program._should_report_metrics:  # pylint: disable=protected-access
      self._should_report_metrics = True

  with self._cluster, tf.container(
      self._container_id), contextlib.ExitStack() as stack:
    if not py_utils.IsEagerMode():
      stack.enter_context(self._graph.as_default())

    if FLAGS.use_tpu_mirrored_vars:
      resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
          FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
      self._tpu_strategy = tf.distribute.experimental.TPUStrategy(
          resolver, device_assignment=self.device_assignment)
      stack.enter_context(self._tpu_strategy.scope())
      stack.enter_context(tpu_strategy._TPUReplicaContext(self._tpu_strategy))
    else:
      stack.enter_context(tf.device(self._cluster.GetPlacer()))

    if FLAGS.pdb_on_exception:
      stack.enter_context(pdb_wrapper.catch_post_mortem())
    with py_utils.VariableStore(), py_utils.VariableRenameScope(
        self._variable_renaming_rules):
      # `BuildTpuSubgraph` has to be called before checkpoint restore, so
      # that the optimizer slot variables are guaranteed to be initialized
      # before they get loaded. Otherwise, the optimizers' slot variables
      # will not be properly loaded when a V1 checkpoint is used.
      for program in self._programs:
        program.BuildTpuSubgraph()
        py_utils.ClearTpuSummaryTensors()

    if not py_utils.IsEagerMode():
      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()
      self._initialize_global_vars = tf.global_variables_initializer()

    checkpointer_models = [
        program.GetModel() for program in self._ckpt_programs
    ]

    if py_utils.IsEagerMode():
      if FLAGS.use_v2_checkpoints_in_eager:
        self._checkpointer = checkpointer.EagerCheckpointerV2(
            self._checkpoint_dir,
            models=checkpointer_models,
            init_op=None,
            train_params=train_cfg.train,
            save_only=False)
      else:
        self._checkpointer = checkpointer.EagerCheckpointerV1(
            self._checkpoint_dir,
            models=checkpointer_models,
            init_op=None,
            train_params=train_cfg.train,
            save_only=False)
    else:
      self._checkpointer = checkpointer.Checkpointer(
          self._checkpoint_dir,
          models=checkpointer_models,
          init_op=self._initialize_global_vars,
          train_params=train_cfg.train,
          save_only=False)

    for program in self._programs:
      program.SetStatusMessageFn(self._SetStatusMessage)

    tpu_embedding_collection = (
        tpu_embedding_layers.TpuEmbeddingCollection.Get())
    self._load_ops = tpu_embedding_collection.load_ops
    self._retrieve_ops = tpu_embedding_collection.retrieve_ops
    self._tpu_embedding = tpu_embedding_collection.tpu_embedding
def Initialize(self, sess):
  if not tf.executing_eagerly():
    self.Reset(sess)
  super().Initialize(sess)
def AddPerExampleTensor(self, name, value):
  if name in self._per_example and not tf.executing_eagerly():
    raise ValueError('Metric %s has already been defined.' % name)
  self._per_example[name] = value
def InitDevicesEager(self):
  assert tf.executing_eagerly()
  self._session_devices = [d.name for d in tf.config.list_logical_devices()]
  tf.logging.info('InitDevices %s' % sorted(self._session_devices))
import lingvo.compat as tf
from lingvo.core import cluster_factory
from lingvo.core import py_utils
import numpy as np

FLAGS = tf.flags.FLAGS

# Enable tf.function when eager execution is on-by-default, which is the case
# when:
# - the test target doesn't depend on the disable_tf2 target, and
# - --define=tf_api_version=1 is not specified during the build.
#
# TODO(laigd): remove TF version check when 312743821 and 313682500 are in the
# release.
if tf.executing_eagerly() and tf.compat.v1.__version__ >= '2.3.0':
  try:
    FLAGS.if_use_tf_function = True
    FLAGS.while_loop_use_tf_function = True
    FLAGS.call_defun_use_tf_function = True
  except tf.flags.UnrecognizedFlagError:
    pass

# Disable eager execution for all tests.
tf.disable_eager_execution()

tf.flags.DEFINE_boolean(
    'update_goldens', False,
    'Update the goldens, rather than diffing against them.')
def testSomeTFSymbols(self):
  self.assertFalse(tf.executing_eagerly())
  self.assertIsNotNone(tf.logging)
  self.assertIsNotNone(tf.flags)
  self.assertIs(tf.Defun, function.Defun)
def Export(cls,
           model_cfg,
           model_task_name=None,
           device_options=InferenceDeviceOptions(
               device='',
               retain_device_placement=False,
               var_options=None,
               gen_init_op=True,
               dtype_override=None,
               fprop_dtype_override=None),
           freeze_checkpoint=None,
           freeze_defaults=False,
           export_path=None,
           subgraph_filter=None,
           random_seed=None,
           disable_packed_input=True,
           prune_graph=True,
           export_graph_collections=False):
  """Exports an InferenceGraph proto with piecewise subgraphs.

  Sets FLAGS.enable_asserts to False unless the user explicitly sets it to
  True.

  Note: Enable FLAGS.pin_vars_to_cpu (default false) to make weight-sharing
  and multi-core inference on TPUs work properly.

  Args:
    model_cfg: a Params instance as returned by
      model_registry.GetParams(modelname, 'Test') or model_params.Model().
    model_task_name: The task to generate an inference graph for. Should be
      None for single-task models.
    device_options: Device options for the accelerator used for serving.
    freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
      given.
    freeze_defaults: Default initializes the graph and freezes it. Useful for
      early testing of downstream tools without having a checkpoint.
    export_path: If not None, write the inference graph in ASCII to this
      path.
    subgraph_filter: A string or a list of subgraph names. If not None or
      empty, export only this list of inference subgraphs.
    random_seed: Fixes the random seed in the exported inference graph.
    disable_packed_input: Disable packed input for inference writing
      purposes.
    prune_graph: If true, prune the graph to just the parts we need.
    export_graph_collections: If true, export graph collections to the
      InferenceGraph proto.

  Returns:
    An InferenceGraph proto.

  Raises:
    ValueError: if the model does not support the listed subgraphs.
  """
  if py_utils.IsEagerMode():
    raise ValueError('InferenceGraph exporter does not work in Eager mode.')
  assert issubclass(model_cfg.cls, base_model.BaseModel)
  if device_options.dtype_override and device_options.fprop_dtype_override:
    raise ValueError(
        'device_options.dtype_override and device_options.fprop_dtype_override '
        'cannot both be set.')
  if subgraph_filter and not isinstance(subgraph_filter, (tuple, list)):
    subgraph_filter = [subgraph_filter]

  # Disable assertions unless user explicitly enables it.
  if FLAGS['enable_asserts'].using_default_value:
    FLAGS.enable_asserts = False

  # TODO(laurenzo): Work out how much we need to specify here in terms of
  # cluster configuration.
  cls._SetClusterParams(model_cfg.cluster, device_options)

  # Configure the model.
  model_cfg.random_seed = random_seed
  model_cfg.is_inference = True

  if disable_packed_input:

    def _DisablePackedInput(task):
      if (_ParamExists(task, 'encoder') and
          _ParamExists(task.encoder, 'packed_input')):
        task.encoder.packed_input = False
      if (_ParamExists(task, 'decoder') and
          _ParamExists(task.decoder, 'packed_input')):
        task.decoder.packed_input = False

    if issubclass(model_cfg.cls, base_model.MultiTaskModel):
      for _, task_param in model_cfg.task_params.IterParams():
        _DisablePackedInput(task_param)
    else:
      _DisablePackedInput(model_cfg.task)

  tf.logging.debug('Model %s params:', model_cfg.name)
  for line in model_cfg.ToText().split('\n'):
    tf.logging.debug('%s', line)

  # Instantiate the graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.random.set_seed(random_seed)
    cluster = model_cfg.cluster.Instantiate()
    device = cluster.GetPlacer()
    tpu_const_scope = _DummyScope()
    if (IsTpu(device_options) and
        device_options.var_options == 'AS_CONSTANTS'):
      # Do not specify devices for variables if we are marking them as
      # constants.
      device = ''
      tpu_const_scope = ConstGuaranteeScope()

    with cluster, tf.device(device), tpu_const_scope:

      bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
          device_options)
      if bfloat16_override:
        py_utils.UpdateDtype(model_cfg, tf.bfloat16)
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      act_bfloat16_override = ShouldForceBfloat16ForActivations(
          device_options)
      if act_bfloat16_override:
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      # Hard-code TPU-related flags prior to instantiating model.
      old_enable_asserts = FLAGS.enable_asserts
      old_xla_device = FLAGS.xla_device
      if IsTpu(device_options):
        FLAGS.enable_asserts = False
        FLAGS.xla_device = 'tpu'

      try:
        mdl = model_cfg.Instantiate()
        task = mdl.GetTask(model_task_name)

        variables_to_restore = (
            _MakeVariableDictionary(tf.global_variables())
            if not mdl.ema else mdl.ema.variables_to_restore(
                mdl.variables_for_ema))

        if bfloat16_override:
          saver_var_spec = (
              bfloat16_variables
              .get_saver_spec_for_variables_with_bf16_overrides(
                  variables_to_restore))
          # For TPU embedding layers, if the table explicitly specifies the
          # inference dtype as bfloat16, the variables in the checkpoint must
          # already be in bfloat16, so we change back to bfloat16 to avoid
          # dtype mismatch.
          for var_name in (tpu_embedding_layers.TpuEmbeddingCollection.Get()
                           .inference_with_bfloat16_var_names):
            saver_var_spec[var_name] = variables_to_restore[var_name]
        else:
          saver_var_spec = variables_to_restore

        saver = tf.train.Saver(saver_var_spec)
        tf.variables_initializer(
            tf.global_variables(), name='init_all_variables')
        if IsTpu(device_options) and device_options.gen_init_op:
          tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

        if freeze_checkpoint or freeze_defaults:
          # Replace variables with tensors using tf.identity in theta before
          # freezing to avoid the graph referencing types of DT_RESOURCE.
          def AddIdentityToTheta(layer):
            # pylint: disable=protected-access
            layer._private_theta = py_utils.Transform(tf.identity,
                                                      layer._private_theta)
            # pylint: enable=protected-access
            layer.children.Transform(AddIdentityToTheta)

          AddIdentityToTheta(task)

        inference_graph_proto = inference_graph_pb2.InferenceGraph()
        subgraphs_proto = task.Inference()
        if isinstance(subgraphs_proto, dict):
          subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
        for name, subgraph in subgraphs_proto.subgraphs.items():
          if not subgraph_filter or name in subgraph_filter:
            inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

        if not inference_graph_proto.subgraphs and subgraph_filter:
          raise ValueError(
              f'Subgraph filters {subgraph_filter} filtered out all '
              'subgraphs. Defined subgraphs: '
              f'{list(subgraphs_proto.subgraphs.keys())}')

        # Yes, graph collections are bad, however this seems to be the
        # easiest way to get these assets registered from
        # TextFileInitializer.
        assets_collection = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
        for asset in assets_collection:
          if (asset.op.type == 'Const' and
              asset.op.get_attr('dtype') == tf.dtypes.string):
            constant_value = asset.op.get_attr('value')
            if constant_value.string_val:
              tf.logging.info('Found asset file_path: %s',
                              constant_value.string_val[0])
              asset_file_def = inference_graph_proto.asset_file_def.add()
              asset_file_def.tensor_info.name = asset.name
              asset_file_def.filename = constant_value.string_val[0]

        # Add a table init op and global variable init op to the graph.
        # Tables can be declared anywhere in the graph, so this op has to be
        # added last.
        tf.tables_initializer(name='init_all_tables')
      finally:
        # Reset TPU-related flags after model instantiation.
        FLAGS.enable_asserts = old_enable_asserts
        FLAGS.xla_device = old_xla_device

  tf.logging.info('Graph contains ops: %r',
                  [op.name for op in graph.get_operations()])

  # Collection defs
  if not tf.executing_eagerly():
    if export_graph_collections:
      meta_graph = tf.train.export_meta_graph(graph=graph)
      for key in meta_graph.collection_def:
        tf.logging.info('copying collection %s', key)
        inference_graph_proto.collection_def[key].CopyFrom(
            meta_graph.collection_def[key])
  else:
    tf.logging.warning('Not exporting collection defs '
                       'since operating in eager mode.')

  # Freezing.
  if freeze_defaults or freeze_checkpoint:
    output_op_names = GetOutputOpNames(
        graph,
        inference_graph_proto,
        preserve_colocation_nodes=False,
        preserve_saver_restore_nodes=False)
    if cls._DeviceSupportsFreezing(device_options):
      raise ValueError('freeze_checkpoint cannot be used with device ' +
                       device_options.device)
    if freeze_checkpoint:
      tf.logging.info('Freezing graph from checkpoint: %s',
                      freeze_checkpoint)
      graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                             output_op_names)
    elif freeze_defaults:
      tf.logging.info('Default initializing graph and freezing.')
      graph_def = _FreezeDefaults(graph, output_op_names)
  else:
    inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())
    graph_def = graph.as_graph_def()

    if prune_graph:
      output_op_names = GetOutputOpNames(graph, inference_graph_proto)

      # Prune the graph to just the parts we need.
      # To support restoring, we have to not prune out the restore node.
      output_op_names.append('init_all_tables')
      output_op_names.append('init_all_variables')
      output_op_names.append('save/control_dependency')
      output_op_names.append('save/restore_all')
      if IsTpu(device_options) and device_options.gen_init_op:
        output_op_names.append('tpu_init_op')
      tf.logging.info('Pruning graph to output ops: %r', output_op_names)
      graph_def = tf.compat.v1.graph_util.extract_sub_graph(
          graph_def, output_op_names)

  if not device_options.retain_device_placement:
    # Clear the device so that the runtime can choose.
    tf.logging.info('Clearing device placement for: %s',
                    device_options.device)
    for node in graph_def.node:
      node.ClearField('device')
    for function in graph_def.library.function:
      for node_def in function.node_def:
        node_def.ClearField('device')

  inference_graph_proto.graph_def.CopyFrom(graph_def)

  if export_path:
    with tf.io.gfile.GFile(export_path, 'w') as f:
      f.write(text_format.MessageToString(inference_graph_proto))

  return inference_graph_proto
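# Hedged usage sketch. The enclosing class is assumed to be
# InferenceGraphExporter (as in lingvo.core.inference_graph_exporter), and
# the registered model name is hypothetical:
# model_cfg = model_registry.GetParams('image.mnist.LeNet5', 'Test')
# inference_graph = InferenceGraphExporter.Export(
#     model_cfg,
#     subgraph_filter=['default'],
#     export_path='/tmp/inference_graph.pbtxt')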
def AddAttentionSummaryBatchMajor(name,
                                  attention_tensors,
                                  src_paddings,
                                  tgt_paddings,
                                  transcripts=None,
                                  max_outputs=3):
  """Adds an image summary showing the attention probability matrix and state.

  Unlike AddAttentionSummary(), this function expects all tensors to have the
  batch dimension in axis 0.

  Args:
    name: Summary name.
    attention_tensors: A list of 3D tensors shaped [batch_size, target_len,
      source_len] where attention[b, i, j] is the probability for the i-th
      output attending to the j-th input for element b in the batch.
    src_paddings: A tensor of binary paddings shaped [batch, source_len] for
      the source sequence, or a list of such tensors of the same length as
      attention_tensors with separate paddings for each entry in
      attention_tensors.
    tgt_paddings: A tensor of binary paddings shaped [batch, target_len] for
      the target sequence, or a list of such tensors of the same length as
      attention_tensors with separate paddings for each entry in
      attention_tensors.
    transcripts: Optional, transcripts shaped [batch, source_len] for the
      source sequence.
    max_outputs: Integer maximum number of elements of the batch to plot.
  """

  def VerifyLen(paddings):
    length = len(paddings) if isinstance(paddings, list) else 1
    if length != 1 and length != len(attention_tensors):
      raise ValueError('Bad length of paddings list {}'.format(length))

  VerifyLen(src_paddings)
  VerifyLen(tgt_paddings)

  # Verify shapes.
  for i, attention_tensor in enumerate(attention_tensors):
    src, tgt = src_paddings, tgt_paddings
    src = src[0 if len(src) == 1 else i] if isinstance(src, list) else src
    tgt = tgt[0 if len(tgt) == 1 else i] if isinstance(tgt, list) else tgt
    tgt_shape = py_utils.GetShape(tgt)
    if not tf.executing_eagerly():
      attention_tensor_name = attention_tensor.name
    else:
      attention_tensor_name = f'[eager]_{name}_{i}'
    attention_tensors[i] = tf.identity(
        py_utils.with_dependencies([
            py_utils.assert_equal(
                py_utils.GetShape(attention_tensor),
                tgt_shape[:2] + [py_utils.GetShape(src)[1]] + tgt_shape[2:])
        ], attention_tensor),
        re.sub(':.*$', '', attention_tensor_name))

  if not _ShouldAddSummary():
    return

  def ToLengths(paddings):
    paddings = paddings if isinstance(paddings, list) else [paddings]
    return [SequenceLength(p) for p in paddings]

  def Get(lengths, i):
    return lengths[0 if len(lengths) == 1 else i]

  src_lens = ToLengths(src_paddings)
  tgt_lens = ToLengths(tgt_paddings)

  with plot.MatplotlibFigureSummary(
      name + '/Attention',
      max_outputs=max_outputs,
      gridspec_kwargs={'hspace': 0.3}) as fig:
    for n, atten in enumerate(attention_tensors):
      # Diagnostic metric that decreases as attention picks up.
      max_entropy = tf.math.log(tf.cast(Get(src_lens, n), tf.float32))
      max_entropy = tf.expand_dims(tf.expand_dims(max_entropy, -1), -1)
      atten_normalized_entropy = -atten * tf.math.log(atten +
                                                      1e-10) / max_entropy
      scalar(name + '/Attention/average_normalized_entropy/%d' % n,
             tf.reduce_mean(atten_normalized_entropy))
      args = [atten, Get(src_lens, n), Get(tgt_lens, n)]
      if transcripts is not None and n == 0:
        args.append(transcripts)
      if not tf.executing_eagerly():
        atten_name = atten.name
      else:
        atten_name = f'[eager]_{name}_{n}'
      fig.AddSubplot(
          args,
          TrimPaddingAndPlotAttention,
          title=atten_name,
          xlabel='Input',
          ylabel='Output')
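# Illustrative call (tensor names are hypothetical): summarize one attention
# probability tensor with shared source/target paddings.
# AddAttentionSummaryBatchMajor(
#     'enc_dec_atten',
#     [atten_probs],     # each entry [batch, target_len, source_len]
#     src_paddings,      # [batch, source_len]
#     tgt_paddings,      # [batch, target_len]
#     max_outputs=2)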