def RestoreIfNeeded(self, sess):
  """If vars are not initialized, restore from checkpoint.

  Does nothing when no variable is reported as uninitialized. Otherwise
  `self._Restore` is tried first; when it returns a truthy value we are done.
  If no `init_from_checkpoint_rules` are configured (neither on any task nor
  on the model params), `self._initialize_vars` is run. When rules are
  configured they are applied through `py_utils.OverrideVarsFromCheckpoints`,
  and every variable that is still uninitialized afterwards is initialized
  explicitly.

  Args:
    sess: tf.Session.
  """
  assert not self._save_only
  uninitialized_var_names = list(sess.run(self._uninitialized_vars))
  if not uninitialized_var_names:
    return

  tf.logging.info('Uninitialized var list: %s ', uninitialized_var_names)
  if self._Restore(sess):
    return

  if (not any(task.params.train.init_from_checkpoint_rules
              for task in self._model_tasks) and
      not self._params.train.init_from_checkpoint_rules):
    tf.logging.info('Initialize ALL variables: %s', uninitialized_var_names)
    sess.run([self._initialize_vars])
    tf.logging.info('Initialize variables done.')
    return

  # There was a race in local run. Another thread will get unblocked once
  # _initialize_all is called. OverrideVarsFromCheckpoints
  # might not happen at the right time.
  #
  # Per-task rules are applied first, then the model-level rules.
  train_params = [task.params.train for task in self._model.tasks]
  train_params.append(self._params.train)
  for tp in train_params:
    if tp.init_from_checkpoint_rules:
      tf.logging.info('OverrideVarsFromCheckpoints %s',
                      tp.init_from_checkpoint_rules)
      py_utils.OverrideVarsFromCheckpoints(sess, self._vars,
                                           tp.init_from_checkpoint_rules)

  uninitialized_var_names = list(sess.run(self._uninitialized_vars))
  if not uninitialized_var_names:
    return

  # uninitialized_var_names is a list of strings without ":0" suffix.
  # tf.report_uninitialized_variables returns binary strings.
  assert all(isinstance(s, six.binary_type) for s in uninitialized_var_names)

  # Need to retrieve vars, removing ":0" suffix from names. A set keeps the
  # membership test constant-time instead of scanning the list per variable.
  uninitialized_name_set = set(uninitialized_var_names)
  uninitialized_vars = [
      v for v in self._vars
      if six.ensure_binary(v.name[:-2]) in uninitialized_name_set
  ]
  tf.logging.info('Initialize variables: %s',
                  [v.name for v in uninitialized_vars])
  sess.run(tf.variables_initializer(uninitialized_vars))
def _create_session(self, *args, **kwargs):
  """Creates a session via the parent class and initializes global_step.

  All positional and keyword arguments are forwarded unchanged to the parent
  `_create_session`.

  Returns:
    The session returned by the parent `_create_session`.
  """
  session = super()._create_session(*args, **kwargs)
  with session.graph.as_default():
    # Ensure the global_step variable is created in every new session.
    step_var = py_utils.GetOrCreateGlobalStepVar()

    def _InitializeStep():
      return tf.variables_initializer([step_var])

    # Only run the initializer when the variable is not yet initialized.
    already_initialized = tf.is_variable_initialized(step_var)
    maybe_init_op = tf.cond(already_initialized, tf.no_op, _InitializeStep)
    session.run(maybe_init_op)
  return session
def Export(cls,
           model_cfg,
           model_task_name=None,
           device_options=InferenceDeviceOptions(
               device='',
               retain_device_placement=False,
               var_options=None,
               gen_init_op=True,
               dtype_override=None),
           freeze_checkpoint=None,
           freeze_defaults=False,
           export_path=None,
           subgraph_filter=None,
           random_seed=None,
           disable_packed_input=True):
  """Exports a InferenceGraph proto with piecewise subgraphs.

  Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

  Note that `model_cfg` is modified in place (`random_seed`, `is_inference`,
  and optionally the encoder/decoder `packed_input` params and dtypes).

  Args:
    model_cfg: a Params instance as returned by
      model_registry.GetParams(modelname, 'Test') or model_params.Model().
    model_task_name: The task to generate an inference graph for. Should be
      None for single-task models.
    device_options: Device options for the accelerator used for serving.
    freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
      given.
    freeze_defaults: Default initializes the graph and freeze. Useful for
      early testing of downstream tools without having a checkpoint.
    export_path: If not None, write the inference graph in ASCII to this path.
    subgraph_filter: A string or a list of subgraph names. If not None or
      empty, export only this list of inference subgraphs. A single string is
      treated as a one-element list.
    random_seed: Fixes the random seed in the exported inference graph.
    disable_packed_input: Disable packed input for inference writing purposes.

  Returns:
    InferenceGraph proto.

  Raises:
    ValueError: if the model does not support the listed subgraphs (i.e. a
      non-empty `subgraph_filter` matches none of the defined subgraphs), or
      if freezing is requested and `cls._DeviceSupportsFreezing` returns a
      truthy value for `device_options`.
  """
  assert issubclass(model_cfg.cls, base_model.BaseModel)

  # A bare string would otherwise turn `name in subgraph_filter` into a
  # substring test, so wrap it into a list for exact-name matching.
  if isinstance(subgraph_filter, str):
    subgraph_filter = [subgraph_filter]

  # Disable assertions unless user explicitly enables it.
  if FLAGS['enable_asserts'].using_default_value:
    FLAGS.enable_asserts = False

  # TODO(laurenzo): Work out how much we need to specify here in terms of
  # cluster configuration.
  cls._SetClusterParams(model_cfg.cluster, device_options)

  # Configure the model.
  model_cfg.random_seed = random_seed
  model_cfg.is_inference = True

  if disable_packed_input:

    def _DisablePackedInput(task):
      if (_ParamExists(task, 'encoder') and
          _ParamExists(task.encoder, 'packed_input')):
        task.encoder.packed_input = False
      if (_ParamExists(task, 'decoder') and
          _ParamExists(task.decoder, 'packed_input')):
        task.decoder.packed_input = False

    if issubclass(model_cfg.cls, base_model.MultiTaskModel):
      for _, task_param in model_cfg.task_params.IterParams():
        _DisablePackedInput(task_param)
    else:
      _DisablePackedInput(model_cfg.task)

  tf.logging.info('Model %s params:', model_cfg.name)
  for line in model_cfg.ToText().split('\n'):
    tf.logging.info('%s', line)

  # Instantiate the graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.random.set_seed(random_seed)
    cluster = model_cfg.cluster.Instantiate()
    device = cluster.GetPlacer()
    tpu_const_scope = _DummyScope()
    if (IsTpu(device_options) and
        device_options.var_options == 'AS_CONSTANTS'):
      # Do not specify devices for variables if we are marking them as
      # constants.
      device = ''
      tpu_const_scope = ConstGuaranteeScope()

    with cluster, tf.device(device), tpu_const_scope:

      bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
          device_options)

      if bfloat16_override:
        py_utils.UpdateDtype(model_cfg, tf.bfloat16)
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      # Hard-code TPU-related flags prior to instantiating model.
      old_enable_asserts = FLAGS.enable_asserts
      old_xla_device = FLAGS.xla_device
      if IsTpu(device_options):
        FLAGS.enable_asserts = False
        FLAGS.xla_device = 'tpu'

      # Ensure the global_step variable is created.
      _ = py_utils.GetOrCreateGlobalStepVar()

      try:
        mdl = model_cfg.Instantiate()
        task = mdl.GetTask(model_task_name)

        variables_to_restore = (
            _MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
            mdl.ema.variables_to_restore(mdl.variables_for_ema))

        if bfloat16_override:
          saver_var_spec = (
              bfloat16_variables
              .get_saver_spec_for_variables_with_bf16_overrides(
                  variables_to_restore))
        else:
          saver_var_spec = variables_to_restore

        saver = tf.train.Saver(saver_var_spec)
        # The init ops below are looked up by name later when pruning.
        tf.variables_initializer(
            tf.global_variables(), name='init_all_variables')
        if IsTpu(device_options) and device_options.gen_init_op:
          tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

        inference_graph_proto = inference_graph_pb2.InferenceGraph()
        subgraphs_proto = task.Inference()
        if isinstance(subgraphs_proto, dict):
          subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
        for name, subgraph in subgraphs_proto.subgraphs.items():
          if not subgraph_filter or name in subgraph_filter:
            inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

        # Fail loudly instead of exporting a graph without any subgraph when
        # none of the requested names exist.
        if subgraph_filter and not inference_graph_proto.subgraphs:
          raise ValueError(
              'Subgraph filters %s filtered out all subgraphs. '
              'Defined subgraphs: %s' %
              (subgraph_filter, list(subgraphs_proto.subgraphs.keys())))

        # Add a table init op and global variable init op to the graph.
        # Tables can be declared anywhere in the graph, so this op has to be
        # added last.
        tf.tables_initializer(name='init_all_tables')
      finally:
        # Reset TPU-related flags after model instantiation.
        FLAGS.enable_asserts = old_enable_asserts
        FLAGS.xla_device = old_xla_device

  tf.logging.info('Graph contains ops: %r',
                  [op.name for op in graph.get_operations()])

  inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

  # Freezing.
  if freeze_defaults or freeze_checkpoint:
    output_op_names = GetOutputOpNames(
        graph, inference_graph_proto, preserve_colocation_nodes=False)
    # NOTE(review): despite the helper's name, a truthy result is treated
    # here as "freezing is not possible on this device" -- confirm against
    # the definition of _DeviceSupportsFreezing.
    if cls._DeviceSupportsFreezing(device_options):
      raise ValueError('freeze_checkpoint cannot be used with device ' +
                       device_options.device)
    if freeze_checkpoint:
      tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
      graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                             output_op_names)
    elif freeze_defaults:
      tf.logging.info('Default initializing graph and freezing.')
      graph_def = _FreezeDefaults(graph, output_op_names)
  else:
    output_op_names = GetOutputOpNames(graph, inference_graph_proto)

    # Prune the graph to just the parts we need.
    # To support restoring, we have to not prune out the restore node.
    output_op_names.append('init_all_tables')
    output_op_names.append('init_all_variables')
    output_op_names.append('save/control_dependency')
    output_op_names.append('save/restore_all')
    if IsTpu(device_options) and device_options.gen_init_op:
      output_op_names.append('tpu_init_op')
    graph_def = graph.as_graph_def()
    tf.logging.info('Pruning graph to output ops: %r', output_op_names)
    graph_def = tf.graph_util.extract_sub_graph(graph_def, output_op_names)

  if not device_options.retain_device_placement:
    # Clear the device so that the runtime can choose.
    tf.logging.info('Clearing device placement for: %s',
                    device_options.device)
    for node in graph_def.node:
      node.ClearField('device')
    for function in graph_def.library.function:
      for node_def in function.node_def:
        node_def.ClearField('device')

  inference_graph_proto.graph_def.CopyFrom(graph_def)

  if export_path:
    with tf.io.gfile.GFile(export_path, 'w') as f:
      f.write(text_format.MessageToString(inference_graph_proto))
  return inference_graph_proto
def RestoreIfNeeded(self, sess):
  """If vars are not initialized, restore from checkpoint.

  Does nothing when `self.GetUninitializedVars` reports no names. Otherwise
  asserts that *all* global variables are uninitialized, then tries
  `self._Restore`. If that does not return a truthy value, any configured
  `init_from_checkpoint_rules` (per task, then model-level) are applied via
  `py_utils.OverrideVarsFromCheckpoints`, and every variable that is still
  uninitialized afterwards is initialized explicitly.

  Args:
    sess: tf.Session.
  """
  assert not self._save_only

  uninitialized_var_names = self.GetUninitializedVars(sess)
  # uninitialized_var_names is a list of strings without ":0" suffix.
  # tf.report_uninitialized_variables returns binary strings.
  assert all(isinstance(s, six.binary_type) for s in uninitialized_var_names)
  if not uninitialized_var_names:
    # All variables are already initialized.
    return

  tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)

  # There should only be uninitialized variables if all variables are
  # uninitialized.
  all_var_name_set = {
      six.ensure_binary(v.name[:-2]) for v in tf.global_variables()
  }
  uninitialized_name_set = set(uninitialized_var_names)
  assert uninitialized_name_set == all_var_name_set, sorted(
      all_var_name_set - uninitialized_name_set)

  if self._Restore(sess):
    # Successfully restored from checkpoint.
    uninitialized_var_names = self.GetUninitializedVars(sess)
    assert not uninitialized_var_names, uninitialized_var_names
    return

  if (self._params.train.init_from_checkpoint_rules or
      any(task.params.train.init_from_checkpoint_rules
          for task in self._model_tasks)):
    # Per-task rules are applied first, then the model-level rules.
    train_params = [task.params.train for task in self._model.tasks]
    train_params.append(self._params.train)
    for tp in train_params:
      if tp.init_from_checkpoint_rules:
        tf.logging.info('OverrideVarsFromCheckpoints %s',
                        tp.init_from_checkpoint_rules)
        py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(),
                                             tp.init_from_checkpoint_rules)

  uninitialized_var_names = self.GetUninitializedVars(sess)
  if not uninitialized_var_names:
    return

  tf.logging.info('Remaining uninitialized vars: %s', uninitialized_var_names)

  # Need to retrieve vars, removing ":0" suffix from names. A set keeps the
  # membership test constant-time instead of scanning the list per variable.
  remaining_name_set = set(uninitialized_var_names)
  uninitialized_vars = [
      v for v in tf.global_variables()
      if six.ensure_binary(v.name[:-2]) in remaining_name_set
  ]
  tf.logging.info('Initialize variables: %s',
                  sorted([v.name[:-2] for v in uninitialized_vars]))
  sess.run(tf.variables_initializer(uninitialized_vars))
def Export(cls,
           model_cfg,
           model_task_name=None,
           device_options=InferenceDeviceOptions(
               device='',
               retain_device_placement=False,
               var_options=None,
               gen_init_op=True,
               dtype_override=None,
               fprop_dtype_override=None),
           freeze_checkpoint=None,
           freeze_defaults=False,
           export_path=None,
           subgraph_filter=None,
           random_seed=None,
           disable_packed_input=True,
           prune_graph=True,
           export_graph_collections=False):
  """Exports a InferenceGraph proto with piecewise subgraphs.

  Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

  Note: Enable FLAGS.pin_vars_to_cpu (default false) to make weight-sharing
  and multi-core inference on TPUs work properly.

  Note that `model_cfg` is modified in place (`random_seed`, `is_inference`,
  and optionally the encoder/decoder `packed_input` params and dtypes).

  Args:
    model_cfg: a Params instance as returned by
      model_registry.GetParams(modelname, 'Test') or model_params.Model().
    model_task_name: The task to generate an inference graph for. Should be
      None for single-task models.
    device_options: Device options for the accelerator used for serving.
    freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
      given.
    freeze_defaults: Default initializes the graph and freeze. Useful for
      early testing of downstream tools without having a checkpoint.
    export_path: If not None, write the inference graph in ASCII to this path.
    subgraph_filter: A string or a list of subgraph names. If not None or
      empty, export only this list of inference subgraphs.
    random_seed: Fixes the random seed in the exported inference graph.
    disable_packed_input: Disable packed input for inference writing purposes.
    prune_graph: If true, prune the graph to just the parts we need. Only
      consulted when the graph is not being frozen.
    export_graph_collections: If true, export graph collections to the
      InferenceGraph proto.

  Returns:
    InferenceGraph proto.

  Raises:
    ValueError: if the model does not support the listed subgraphs, if called
      in Eager mode, if both `dtype_override` and `fprop_dtype_override` are
      set in `device_options`, or if freezing is requested and
      `cls._DeviceSupportsFreezing` returns a truthy value.
  """
  if py_utils.IsEagerMode():
    raise ValueError('InferenceGraph exporter does not work in Eager mode.')
  assert issubclass(model_cfg.cls, base_model.BaseModel)
  if device_options.dtype_override and device_options.fprop_dtype_override:
    raise ValueError(
        'device_options{dtype_override,fprop_dtype_override} can not both be '
        'set.')
  # Wrap a single name so that the `in` test below matches whole names.
  if subgraph_filter and not isinstance(subgraph_filter, (tuple, list)):
    subgraph_filter = [subgraph_filter]

  # Disable assertions unless user explicitly enables it.
  if FLAGS['enable_asserts'].using_default_value:
    FLAGS.enable_asserts = False

  # TODO(laurenzo): Work out how much we need to specify here in terms of
  # cluster configuration.
  cls._SetClusterParams(model_cfg.cluster, device_options)

  # Configure the model.
  model_cfg.random_seed = random_seed
  model_cfg.is_inference = True

  if disable_packed_input:

    def _DisablePackedInput(task):
      if (_ParamExists(task, 'encoder') and
          _ParamExists(task.encoder, 'packed_input')):
        task.encoder.packed_input = False
      if (_ParamExists(task, 'decoder') and
          _ParamExists(task.decoder, 'packed_input')):
        task.decoder.packed_input = False

    if issubclass(model_cfg.cls, base_model.MultiTaskModel):
      for _, task_param in model_cfg.task_params.IterParams():
        _DisablePackedInput(task_param)
    else:
      _DisablePackedInput(model_cfg.task)

  tf.logging.debug('Model %s params:', model_cfg.name)
  for line in model_cfg.ToText().split('\n'):
    tf.logging.debug('%s', line)

  # Instantiate the graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.random.set_seed(random_seed)
    cluster = model_cfg.cluster.Instantiate()
    device = cluster.GetPlacer()
    tpu_const_scope = _DummyScope()
    if (IsTpu(device_options) and
        device_options.var_options == 'AS_CONSTANTS'):
      # Do not specify devices for variables if we are marking them as
      # constants.
      device = ''
      tpu_const_scope = ConstGuaranteeScope()

    with cluster, tf.device(device), tpu_const_scope:

      bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
          device_options)

      if bfloat16_override:
        py_utils.UpdateDtype(model_cfg, tf.bfloat16)
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      act_bfloat16_override = ShouldForceBfloat16ForActivations(
          device_options)
      if act_bfloat16_override:
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      # Hard-code TPU-related flags prior to instantiating model.
      old_enable_asserts = FLAGS.enable_asserts
      old_xla_device = FLAGS.xla_device
      if IsTpu(device_options):
        FLAGS.enable_asserts = False
        FLAGS.xla_device = 'tpu'

      try:
        mdl = model_cfg.Instantiate()
        task = mdl.GetTask(model_task_name)

        variables_to_restore = (
            _MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
            mdl.ema.variables_to_restore(mdl.variables_for_ema))

        if bfloat16_override:
          saver_var_spec = (
              bfloat16_variables
              .get_saver_spec_for_variables_with_bf16_overrides(
                  variables_to_restore))
          # For TPU embedding layers, if the table explicitly specifies the
          # inference dtype as bfloat16, the variables in the checkpoint must
          # already be in bfloat16, so we change back to bfloat16 to avoid
          # dtype mismatch.
          for var_name in (tpu_embedding_layers.TpuEmbeddingCollection.Get()
                           .inference_with_bfloat16_var_names):
            saver_var_spec[var_name] = variables_to_restore[var_name]
        else:
          saver_var_spec = variables_to_restore

        saver = tf.train.Saver(saver_var_spec)
        # The init ops below are looked up by name later when pruning.
        tf.variables_initializer(
            tf.global_variables(), name='init_all_variables')
        if IsTpu(device_options) and device_options.gen_init_op:
          tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

        if freeze_checkpoint or freeze_defaults:
          # Replace variables with tensors using tf.identity in theta before
          # freezing to avoid the graph referencing types of DT_RESOURCE.
          def AddIdentityToTheta(layer):
            # pylint: disable=protected-access
            layer._private_theta = py_utils.Transform(tf.identity,
                                                      layer._private_theta)
            # pylint: enable=protected-access
            layer.children.Transform(AddIdentityToTheta)

          AddIdentityToTheta(task)

        inference_graph_proto = inference_graph_pb2.InferenceGraph()
        subgraphs_proto = task.Inference()
        if isinstance(subgraphs_proto, dict):
          subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
        for name, subgraph in subgraphs_proto.subgraphs.items():
          if not subgraph_filter or name in subgraph_filter:
            inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

        if not inference_graph_proto.subgraphs and subgraph_filter:
          raise ValueError(
              f'Subgraph filters {subgraph_filter} filtered out all '
              'subgraphs. Defined subgraphs: '
              f'{list(subgraphs_proto.subgraphs.keys())}')

        # Yes, graph collections are bad, however this seems to be the
        # easiest way to get this assets registered from
        # TextFileInitializer.
        assets_collection = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
        for asset in assets_collection:
          if asset.op.type == 'Const' and asset.op.get_attr(
              'dtype') == tf.dtypes.string:
            constant_value = asset.op.get_attr('value')
            if constant_value.string_val:
              tf.logging.info('Found asset file_path: %s',
                              constant_value.string_val[0])
              asset_file_def = inference_graph_proto.asset_file_def.add()
              asset_file_def.tensor_info.name = asset.name
              asset_file_def.filename = constant_value.string_val[0]

        # Add a table init op and global variable init op to the graph.
        # Tables can be declared anywhere in the graph, so this op has to be
        # added last.
        tf.tables_initializer(name='init_all_tables')
      finally:
        # Reset TPU-related flags after model instantiation.
        FLAGS.enable_asserts = old_enable_asserts
        FLAGS.xla_device = old_xla_device

  tf.logging.info('Graph contains ops: %r',
                  [op.name for op in graph.get_operations()])

  # Collection defs
  # NOTE(review): Eager mode is already rejected at the top of this function
  # via py_utils.IsEagerMode(); the else branch below looks redundant unless
  # that helper can disagree with tf.executing_eagerly() -- confirm.
  if not tf.executing_eagerly():
    if export_graph_collections:
      meta_graph = tf.train.export_meta_graph(graph=graph)
      for key in meta_graph.collection_def:
        tf.logging.info('copying collection %s', key)
        inference_graph_proto.collection_def[key].CopyFrom(
            meta_graph.collection_def[key])
  else:
    tf.logging.warning('Not exporting collection defs '
                       'since operating in eager mode.')

  # Freezing.
  if freeze_defaults or freeze_checkpoint:
    output_op_names = GetOutputOpNames(
        graph,
        inference_graph_proto,
        preserve_colocation_nodes=False,
        preserve_saver_restore_nodes=False)
    # NOTE(review): despite the helper's name, a truthy result is treated
    # here as "freezing is not possible on this device" -- confirm against
    # the definition of _DeviceSupportsFreezing.
    if cls._DeviceSupportsFreezing(device_options):
      raise ValueError('freeze_checkpoint cannot be used with device ' +
                       device_options.device)
    if freeze_checkpoint:
      tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
      graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                             output_op_names)
    elif freeze_defaults:
      tf.logging.info('Default initializing graph and freezing.')
      graph_def = _FreezeDefaults(graph, output_op_names)
  else:
    # Only an unfrozen graph carries a saver definition.
    inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())
    graph_def = graph.as_graph_def()

    if prune_graph:
      output_op_names = GetOutputOpNames(graph, inference_graph_proto)

      # Prune the graph to just the parts we need.
      # To support restoring, we have to not prune out the restore node.
      output_op_names.append('init_all_tables')
      output_op_names.append('init_all_variables')
      output_op_names.append('save/control_dependency')
      output_op_names.append('save/restore_all')
      if IsTpu(device_options) and device_options.gen_init_op:
        output_op_names.append('tpu_init_op')

      tf.logging.info('Pruning graph to output ops: %r', output_op_names)
      graph_def = tf.compat.v1.graph_util.extract_sub_graph(
          graph_def, output_op_names)

  if not device_options.retain_device_placement:
    # Clear the device so that the runtime can choose.
    tf.logging.info('Clearing device placement for: %s',
                    device_options.device)
    for node in graph_def.node:
      node.ClearField('device')
    for function in graph_def.library.function:
      for node_def in function.node_def:
        node_def.ClearField('device')

  inference_graph_proto.graph_def.CopyFrom(graph_def)

  if export_path:
    with tf.io.gfile.GFile(export_path, 'w') as f:
      f.write(text_format.MessageToString(inference_graph_proto))
  return inference_graph_proto