def Restore(self, sess, force_reinitialize=False):
  """Restores from the latest checkpoint, falling back to initialization.

  If `_RestoreFromLatestCheckpoint` succeeds, every variable must then be
  initialized. Otherwise all variables are initialized by running
  `self._init_op`. After that, any `init_from_checkpoint_rules` found on the
  model's tasks (first) and on the model's own train params (last) are
  applied via `py_utils.OverrideVarsFromCheckpoints`.

  Args:
    sess: Session used to restore, query and initialize variables.
    force_reinitialize: When False, asserts that no variable other than
      `global_step` was already initialized before running the initializer.
  """
  if self._RestoreFromLatestCheckpoint(sess):
    # The checkpoint must have provided a value for every variable.
    missing = self._GetUninitializedVarNames(sess)
    assert not missing, missing
    return

  # No checkpoint was restored; initialize from scratch.
  missing = self._GetUninitializedVarNames(sess)
  tf.logging.info('Uninitialized var list: %s', missing)
  if not force_reinitialize:
    # A partially-initialized graph is unexpected here. global_step is exempt
    # because RestoreGlobalStepIfNeeded in the _LoopEnqueue of TrainerTpu may
    # have initialized it already.
    preinitialized = {
        six.ensure_binary(var.name[:-2]) for var in tf.global_variables()
    }
    preinitialized.difference_update(missing)
    preinitialized.discard(b'global_step')
    assert not preinitialized, (
        'Already initialized vars: %s' % sorted(preinitialized))

  # Either the check above passed or the caller asked to force
  # reinitialization, so run the global initializer.
  sess.run(self._init_op)
  tf.logging.info('Initialized all vars.')

  # TODO(b/160786085): Move this logic into Overriding vars logic itself,
  # which requires refactoring things out of py_utils to avoid circular deps.
  def _CandidateTrainParams():
    # Task-level train params first, then the model-level ones; overrides are
    # applied in exactly this order.
    for task in self._model.tasks:
      yield task.params.train
    yield self._params.train

  for tp in _CandidateTrainParams():
    if not tp.init_from_checkpoint_rules:
      continue
    # Resolve each checkpoint key to a concrete checkpoint path.
    resolved_rules = {
        GetSpecificCheckpoint(ckpt_path): var_rules
        for ckpt_path, var_rules in tp.init_from_checkpoint_rules.items()
    }
    tf.logging.info('OverrideVarsFromCheckpoints %s', resolved_rules)
    py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(),
                                         resolved_rules)
def __init__(self, train_dir, model, train_params=None, save_only=False):
  """Initialize Checkpointer.

  Args:
    train_dir: Training directory for saving checkpoints.
    model: A BaseModel instance or None. Its params and tasks are read
      whenever `save_only` is False, so None is only usable together with
      `train_params` and `save_only=True`.
    train_params: If specified, use these training params instead of those in
      the `model`.
    save_only: This checkpointer is only intended for saving checkpoints.
  """
  self._train_dir = train_dir
  self._save_only = save_only

  self._save_path = os.path.join(self._train_dir, 'ckpt')

  if train_params:
    self._train_params = train_params
    self._model = None
  else:
    assert model
    self._train_params = model.params.train
    self._model = model

  if self._save_only:
    # Always define the attribute, so a save-only checkpointer reads None here
    # instead of raising AttributeError.
    self._params = None
  else:
    self._params = model.params
    self._model_tasks = model.tasks
    # The restore paths read self._model.tasks, so keep the model even when
    # `train_params` overrode the training params above.
    self._model = model

  self._next_checkpoint_seconds = 0
  self._save_interval_seconds = self._train_params.save_interval_seconds
  self._saver = self._GetSaver()

  self._uninitialized_vars = tf.report_uninitialized_variables(
      tf.global_variables())
def Restore(self, sess, force_reinitialize=False):
  """Restores from the latest checkpoint, falling back to initialization.

  When no checkpoint is restored, all variables are initialized with
  `self._init_op`. Then every callable in `self._restore_fns` is invoked with
  `sess`; these are presumably built from init_from_checkpoint_rules, as the
  log message suggests.

  Args:
    sess: Session used to restore, query and initialize variables.
    force_reinitialize: When False, asserts that no variable other than
      `global_step` was already initialized before running the initializer.
  """
  if self._RestoreFromLatestCheckpoint(sess):
    # A restored checkpoint must leave nothing uninitialized.
    still_uninitialized = self._GetUninitializedVarNames(sess)
    assert not still_uninitialized, still_uninitialized
    return

  still_uninitialized = self._GetUninitializedVarNames(sess)
  tf.logging.info('Uninitialized var list: %s', still_uninitialized)
  if not force_reinitialize:
    # Expect everything to be uninitialized, except global_step, which
    # RestoreGlobalStepIfNeeded in the _LoopEnqueue of TrainerTpu may set.
    every_var_name = set(
        six.ensure_binary(var.name[:-2]) for var in tf.global_variables())
    already_initialized_vars = every_var_name - set(still_uninitialized)
    already_initialized_vars.discard(b'global_step')
    assert not already_initialized_vars, (
        'Already initialized vars: %s' % sorted(already_initialized_vars))

  sess.run(self._init_op)
  tf.logging.info('Initialized all vars.')

  if not self._restore_fns:
    return
  for restore_fn in self._restore_fns:
    restore_fn(sess)
  tf.logging.info('Restored vars using checkpoint rules.')
def Restore(self, sess, force_reinitialize=False):
  """Restores from the latest checkpoint, falling back to initialization.

  When no checkpoint is restored, a fresh `tf.global_variables_initializer()`
  is run. Then `init_from_checkpoint_rules` from each task's train params and
  from the model's train params, in that order, are passed unchanged to
  `py_utils.OverrideVarsFromCheckpoints`.

  Args:
    sess: Session used to restore, query and initialize variables.
    force_reinitialize: When False, asserts that no variable other than
      `global_step` was already initialized before running the initializer.
  """
  if self._RestoreFromLatestCheckpoint(sess):
    leftover_names = self._GetUninitializedVarNames(sess)
    assert not leftover_names, leftover_names
    return

  leftover_names = self._GetUninitializedVarNames(sess)
  tf.logging.info('Uninitialized var list: %s', leftover_names)
  if not force_reinitialize:
    # Only global_step may be initialized at this point; it can be set by
    # RestoreGlobalStepIfNeeded in the _LoopEnqueue of TrainerTpu.
    names_in_graph = {
        six.ensure_binary(var.name[:-2]) for var in tf.global_variables()
    }
    unexpected = names_in_graph - set(leftover_names)
    unexpected.discard(b'global_step')
    assert not unexpected, ('Already initialized vars: %s' % sorted(unexpected))

  sess.run(tf.global_variables_initializer())
  tf.logging.info('Initialized all vars.')

  def _ApplyCheckpointRules(train_params):
    # Overrides already-initialized variables using the configured rules.
    ckpt_rules = train_params.init_from_checkpoint_rules
    tf.logging.info('OverrideVarsFromCheckpoints %s', ckpt_rules)
    py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(),
                                         ckpt_rules)

  for task in self._model.tasks:
    if task.params.train.init_from_checkpoint_rules:
      _ApplyCheckpointRules(task.params.train)
  if self._params.train.init_from_checkpoint_rules:
    _ApplyCheckpointRules(self._params.train)
def __init__(self,
             train_dir,
             models,
             init_op=None,
             train_params=None,
             save_only=False):
  """Initialize Checkpointer.

  Args:
    train_dir: Training directory for saving checkpoints.
    models: One or a list of BaseModel instances. Cannot be empty. If there
      are more than one models and `train_params` is None, the save intervals
      will be only determined by the first model.
    init_op: The initialize variables op. If unset, it will call
      tf.global_variables_initializer().
    train_params: If specified, use these training params instead of those
      in the `model`.
    save_only: This checkpointer is only intended for saving checkpoints.
  """
  self._train_dir = train_dir
  self._save_only = save_only
  # Fall back to an initializer over the graph's global variables.
  self._init_op = init_op if init_op else tf.global_variables_initializer()
  self._save_path = os.path.join(self._train_dir, 'ckpt')

  # Accept either a single model or a list/tuple of models.
  self._models = models if isinstance(models, (list, tuple)) else [models]
  # Without explicit train params, the first model's params decide the save
  # cadence.
  self._train_params = (
      train_params if train_params else self._models[0].params.train)

  self._next_checkpoint_seconds = 0
  self._save_interval_seconds = self._train_params.save_interval_seconds
  self._save_interval_steps = self._train_params.save_interval_steps
  self._prev_ckpt_step = None
  self._saver = self._GetSaver()
  if not py_utils.IsEagerMode():
    # Only built outside eager mode.
    self._uninitialized_vars = tf.report_uninitialized_variables(
        tf.global_variables())

  self._BuildInitFromCheckpointRules()
def _WrapNonLingvoVars(dest_layer: base_layer.BaseLayer,
                       variables: Collection[tf.Variable],
                       trainable_variables: Collection[tf.Variable] = ()):
  """Registers externally-created variables with a lingvo layer.

  This lets variables created outside lingvo be handled by lingvo's trainer
  and checkpointer:
    - each entry of `variables` is stored in `dest_layer._private_vars`, so it
      is trackable through `dest_layer.vars`. A `tf.identity` of it, placed on
      the variable's own device, is stored as the layer's theta.
    - each entry of `variables` must already be in `tf.global_variables()`,
      so the trainer can initialize it. This is asserted, not fixed up.
    - each entry of `trainable_variables` missing from
      `tf.trainable_variables()` is added to that graph collection with a
      warning, making it visible to the learner.

  Args:
    dest_layer: Lingvo layer to add the `variables` to.
    variables: The non-lingvo variables to wrap.
    trainable_variables: The subset of `variables` to ensure are trainable.
  """
  registered_globals = set(tf.global_variables())
  for var in variables:
    assert var in registered_globals
    # Drop the ":<output index>" suffix from the tensor name.
    short_name = var.name.partition(':')[0]
    # pylint: disable=protected-access
    dest_layer._private_vars[short_name] = var
    with tf.device(var.device):
      dest_layer._private_theta[short_name] = tf.identity(var)
    # pylint: enable=protected-access

  registered_trainables = set(tf.trainable_variables())
  for var in trainable_variables:
    if var in registered_trainables:
      continue
    tf.logging.warning(
        'Wrapped var %s not in trainable collection; adding it.', var.name)
    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES,
                                   var)
def testAccumulator(self):
  """Checks that Accumulator(SGD) over 2 steps matches averaged plain SGD.

  The test compares:
    - explicit averaging of independently computed var_grads1 and
      var_grads2 (plain SGD applied twice, each gradient scaled by 1/2),
    - the Accumulator(SGD) optimizer with accum_steps=2 effectively doing
      this over 2 steps.
  Both runs build the same ProjectionLayer under the same seeds and are fed
  the same two inputs.
  """
  # Two fixed, independently seeded inputs shared by both runs.
  np.random.seed(12345)
  np_input1 = np.random.normal(0.1, 0.5, [2, 4, 3])
  np.random.seed(12346)
  np_input2 = np.random.normal(0.1, 0.5, [2, 4, 3])

  # Reference run: plain SGD, two half-scaled updates chained via control
  # dependencies.
  with self.session(use_gpu=True, graph=tf.Graph()) as sess:
    tf.random.set_seed(123456)
    params = layers.ProjectionLayer.Params()
    params.name = 'proj'
    params.dtype = tf.float64
    params.input_dim = 3
    params.output_dim = 2
    params.params_init = py_utils.WeightInit.Gaussian(0.01, 123456)
    params.batch_norm = False
    proj_layer = layers.ProjectionLayer(params)
    inputs1 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
    in_padding1 = tf.zeros([2, 4, 1], dtype=tf.float64)
    inputs2 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
    in_padding2 = tf.zeros([2, 4, 1], dtype=tf.float64)
    output1 = proj_layer.FPropDefaultTheta(inputs1, in_padding1)
    output2 = proj_layer.FPropDefaultTheta(inputs2, in_padding2)
    loss1 = tf.reduce_sum(output1)
    loss2 = tf.reduce_sum(output2)
    var_grads1 = py_utils.ComputeGradients(loss1, proj_layer.vars)
    var_grads2 = py_utils.ComputeGradients(loss2, proj_layer.vars)

    op = optimizer.SGD.Params()
    opt = op.Instantiate()
    lr = 1e-1
    # Both losses are computed before the first update, and the second update
    # runs after the first, so the two gradients are taken at the same
    # initial weights.
    with tf.control_dependencies([loss1, loss2]):
      var_update_op1 = opt.Apply(
          lr, py_utils.ApplyGradMultiplier(var_grads1, 1. / 2.))
      with tf.control_dependencies([var_update_op1]):
        var_update_op2 = opt.Apply(
            lr, py_utils.ApplyGradMultiplier(var_grads2, 1. / 2.))

    self.evaluate(tf.global_variables_initializer())
    vars1 = self.evaluate(proj_layer.vars.Flatten())
    loss1_1, grads1_1, loss1_2, grads1_2 = sess.run(
        [
            loss1, var_grads1.Transform(tuple), loss2,
            var_grads2.Transform(tuple)
        ],
        feed_dict={
            inputs1: np_input1,
            inputs2: np_input2,
        },
    )
    # Running var_update_op2 also triggers var_update_op1 via the control
    # dependency above.
    sess.run([var_update_op2],
             feed_dict={
                 inputs1: np_input1,
                 inputs2: np_input2,
             })
    vars1_1 = self.evaluate(proj_layer.vars.Flatten())

  # Accumulator run: same layer and seeds. A single placeholder is fed
  # np_input1, then np_input2, across two update steps.
  with self.session(use_gpu=True, graph=tf.Graph()) as sess:
    tf.random.set_seed(123456)
    params = layers.ProjectionLayer.Params()
    params.name = 'proj'
    params.dtype = tf.float64
    params.input_dim = 3
    params.output_dim = 2
    params.params_init = py_utils.WeightInit.Gaussian(0.01, 123456)
    params.batch_norm = False
    proj_layer = layers.ProjectionLayer(params)
    in_padding1 = tf.zeros([2, 4, 1], dtype=tf.float64)
    inputs1 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
    output1 = proj_layer.FPropDefaultTheta(inputs1, in_padding1)
    loss = tf.reduce_sum(output1)
    var_grads = py_utils.ComputeGradients(loss, proj_layer.vars)

    op = optimizer.Accumulator.Params().Set(
        accum_steps=2, dtype=tf.float64, optimizer_tpl=optimizer.SGD.Params())
    opt = op.Instantiate()
    lr = 1e-1
    # add_summary=True so the learning-rate summary checked below is emitted.
    with cluster_factory.ForTestingWorker(add_summary=True):
      var_update_op = opt.Apply(lr, var_grads)
    increment_global_step_op = tf.assign_add(
        py_utils.GetOrCreateGlobalStepVar(), 1)

    self.evaluate(tf.global_variables_initializer())
    vars2 = self.evaluate(proj_layer.vars.Flatten())
    loss2_1, grads2_1 = sess.run(
        [loss, var_grads.Transform(tuple)], feed_dict={
            inputs1: np_input1,
        })
    loss2_2, grads2_2 = sess.run(
        [loss, var_grads.Transform(tuple)], feed_dict={
            inputs1: np_input2,
        })
    # Snapshot the accumulator slot (selected by its 'grad_accumulator' name)
    # before and after each step.
    acc_0 = self.evaluate([
        v for v in tf.global_variables() if 'grad_accumulator' in v.name
    ])[0]
    sess.run([var_update_op], feed_dict={
        inputs1: np_input1,
    })
    acc_1 = self.evaluate([
        v for v in tf.global_variables() if 'grad_accumulator' in v.name
    ])[0]
    vars2_intermediate = self.evaluate(proj_layer.vars.Flatten())
    # Advance global_step between the two accumulation steps.
    self.evaluate(increment_global_step_op)
    sess.run([var_update_op], feed_dict={
        inputs1: np_input2,
    })
    acc_2 = self.evaluate([
        v for v in tf.global_variables() if 'grad_accumulator' in v.name
    ])[0]
    vars2_1 = self.evaluate(proj_layer.vars.Flatten())

    summary = tf.Summary.FromString(self.evaluate(tf.summary.merge_all()))
    tf.logging.info(f'summary: {summary}')
    self.assertEqual(summary.value[0].tag, 'sgd_lr')

    # Entries of `Transform(tuple)` appear to be (var, grad) pairs: index 0
    # matches the variable value and index 1 is used as the gradient below.
    self.assertAllClose(vars1, vars2)
    # The accumulator starts at zero, holds the first gradient after step 1,
    # and is reset to zero after step 2 applies the update.
    self.assertAllClose(acc_0, np.zeros_like(acc_0))
    self.assertAllClose(acc_1, grads2_1['w'][1])
    self.assertAllClose(acc_2, np.zeros_like(acc_0))
    self.assertAllClose(loss1_1, loss2_1)
    self.assertAllClose(loss1_2, loss2_2)
    self.assertAllClose(grads1_1, grads2_1)
    self.assertAllClose(grads1_2, grads2_2)
    # Weights must not change after only the first accumulation step.
    self.assertAllClose(vars1, vars2_intermediate)

    self.assertAllClose(vars2[0], grads2_1['w'][0])
    self.assertAllClose(vars2[0], grads2_2['w'][0])

    # Both runs must match w - lr * mean(grad1, grad2).
    self.assertAllClose(
        vars1[0] - 0.5 * lr * (grads1_1['w'][1] + grads1_2['w'][1]),
        vars1_1[0])

    self.assertAllClose(
        vars2[0] - 0.5 * lr * (grads2_1['w'][1] + grads2_2['w'][1]),
        vars2_1[0])

    self.assertAllClose(vars2, vars2_intermediate)
    self.assertAllClose(vars1_1, vars2_1)
def Export(cls,
           model_cfg,
           model_task_name=None,
           device_options=InferenceDeviceOptions(
               device='',
               retain_device_placement=False,
               var_options=None,
               gen_init_op=True,
               dtype_override=None),
           freeze_checkpoint=None,
           freeze_defaults=False,
           export_path=None,
           subgraph_filter=None,
           random_seed=None,
           disable_packed_input=True):
  """Exports a InferenceGraph proto with piecewise subgraphs.

  Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

  Args:
    model_cfg: a Params instance as returned by
      model_registry.GetParams(modelname, 'Test') or model_params.Model().
    model_task_name: The task to generate an inference graph for. Should be
      None for single-task models.
    device_options: Device options for the accelerator used for serving.
    freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
      given.
    freeze_defaults: Default initializes the graph and freeze. Useful for
      early testing of downstream tools without having a checkpoint.
    export_path: If not None, write the inference graph in ASCII to this path.
    subgraph_filter: A list of subgraph names. If not None or empty, export
      only this list of inference subgraphs.
    random_seed: Fixes the random seed in the exported inference graph.
    disable_packed_input: Disable packed input for inference writing purposes.

  Returns:
    InferenceGraph proto.

  Raises:
    ValueError: if the model does not support the listed subgraphs, or if
      freezing is requested and `cls._DeviceSupportsFreezing(device_options)`
      returns True.
  """
  assert issubclass(model_cfg.cls, base_model.BaseModel)

  # Disable assertions unless user explicitly enables it.
  if FLAGS['enable_asserts'].using_default_value:
    FLAGS.enable_asserts = False

  # TODO(laurenzo): Work out how much we need to specify here in terms of
  # cluster configuration.
  cls._SetClusterParams(model_cfg.cluster, device_options)

  # Configure the model.
  model_cfg.random_seed = random_seed
  model_cfg.is_inference = True

  if disable_packed_input:

    def _DisablePackedInput(task):
      # Turns off packed_input on the task's encoder/decoder when they exist
      # and define that param.
      if (_ParamExists(task, 'encoder') and
          _ParamExists(task.encoder, 'packed_input')):
        task.encoder.packed_input = False
      if (_ParamExists(task, 'decoder') and
          _ParamExists(task.decoder, 'packed_input')):
        task.decoder.packed_input = False

    if issubclass(model_cfg.cls, base_model.MultiTaskModel):
      for _, task_param in model_cfg.task_params.IterParams():
        _DisablePackedInput(task_param)
    else:
      _DisablePackedInput(model_cfg.task)

  tf.logging.info('Model %s params:', model_cfg.name)
  for line in model_cfg.ToText().split('\n'):
    tf.logging.info('%s', line)

  # Instantiate the graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.random.set_seed(random_seed)
    cluster = model_cfg.cluster.Instantiate()
    device = cluster.GetPlacer()
    tpu_const_scope = _DummyScope()
    if (IsTpu(device_options) and
        device_options.var_options == 'AS_CONSTANTS'):
      # Do not specify devices for variables if we are marking them as
      # constants.
      device = ''
      tpu_const_scope = ConstGuaranteeScope()

    with cluster, tf.device(device), tpu_const_scope:

      bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
          device_options)

      if bfloat16_override:
        py_utils.UpdateDtype(model_cfg, tf.bfloat16)
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      # Hard-code TPU-related flags prior to instantiating model.
      old_enable_asserts = FLAGS.enable_asserts
      old_xla_device = FLAGS.xla_device
      if IsTpu(device_options):
        FLAGS.enable_asserts = False
        FLAGS.xla_device = 'tpu'

      # Ensure the global_step variable is created.
      # NOTE(review): this call runs after the FLAGS overrides above but
      # before the `try`. If it raises, the `finally` below never restores
      # FLAGS.enable_asserts / FLAGS.xla_device.
      _ = py_utils.GetOrCreateGlobalStepVar()
      try:
        mdl = model_cfg.Instantiate()
        task = mdl.GetTask(model_task_name)

        # With EMA enabled, restore through the EMA's variable mapping.
        variables_to_restore = (
            _MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
            mdl.ema.variables_to_restore(mdl.variables_for_ema))

        if bfloat16_override:
          saver_var_spec = (
              bfloat16_variables
              .get_saver_spec_for_variables_with_bf16_overrides(
                  variables_to_restore))
        else:
          saver_var_spec = variables_to_restore

        saver = tf.train.Saver(saver_var_spec)
        # Created for its named op only. 'init_all_variables' is kept as a
        # pruning root further down.
        tf.variables_initializer(
            tf.global_variables(), name='init_all_variables')
        if IsTpu(device_options) and device_options.gen_init_op:
          tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

        inference_graph_proto = inference_graph_pb2.InferenceGraph()
        subgraphs_proto = task.Inference()
        if isinstance(subgraphs_proto, dict):
          subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
        for name, subgraph in subgraphs_proto.subgraphs.items():
          if not subgraph_filter or name in subgraph_filter:
            inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

        # Add a table init op and global variable init op to the graph.
        # Tables can be declared anywhere in the graph, so this op has to be
        # added last.
        tf.tables_initializer(name='init_all_tables')
      finally:
        # Reset TPU-related flags after model instantiation.
        FLAGS.enable_asserts = old_enable_asserts
        FLAGS.xla_device = old_xla_device

  tf.logging.info('Graph contains ops: %r',
                  [op.name for op in graph.get_operations()])

  inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

  # Freezing.
  if freeze_defaults or freeze_checkpoint:
    output_op_names = GetOutputOpNames(
        graph, inference_graph_proto, preserve_colocation_nodes=False)
    # NOTE(review): this raises when _DeviceSupportsFreezing(...) is True,
    # which reads inverted relative to the helper's name. Confirm against its
    # definition.
    if cls._DeviceSupportsFreezing(device_options):
      raise ValueError('freeze_checkpoint cannot be used with device ' +
                       device_options.device)
    if freeze_checkpoint:
      tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
      graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                             output_op_names)
    elif freeze_defaults:
      tf.logging.info('Default initializing graph and freezing.')
      graph_def = _FreezeDefaults(graph, output_op_names)
  else:
    output_op_names = GetOutputOpNames(graph, inference_graph_proto)

    # Prune the graph to just the parts we need.
    # To support restoring, we have to not prune out the restore node.
    # The 'save/...' names are presumably the nodes created by the
    # tf.train.Saver above.
    output_op_names.append('init_all_tables')
    output_op_names.append('init_all_variables')
    output_op_names.append('save/control_dependency')
    output_op_names.append('save/restore_all')
    if IsTpu(device_options) and device_options.gen_init_op:
      output_op_names.append('tpu_init_op')
    graph_def = graph.as_graph_def()
    tf.logging.info('Pruning graph to output ops: %r', output_op_names)
    graph_def = tf.graph_util.extract_sub_graph(graph_def, output_op_names)

  if not device_options.retain_device_placement:
    # Clear the device so that the runtime can choose.
    tf.logging.info('Clearing device placement for: %s',
                    device_options.device)
    for node in graph_def.node:
      node.ClearField('device')
    for function in graph_def.library.function:
      for node_def in function.node_def:
        node_def.ClearField('device')

  inference_graph_proto.graph_def.CopyFrom(graph_def)

  if export_path:
    with tf.io.gfile.GFile(export_path, 'w') as f:
      f.write(text_format.MessageToString(inference_graph_proto))
  return inference_graph_proto
def __init__(self,
             train_dir,
             model,
             init_op=None,
             train_params=None,
             save_only=False):
  """Initialize Checkpointer.

  Besides setting up saving, this builds one restore callable per configured
  `init_from_checkpoint_rules` (task-level first, then model-level) and
  stores them in `self._restore_fns`.

  Args:
    train_dir: Training directory for saving checkpoints.
    model: A BaseModel instance or None.
    init_op: The initialize variables op. If unset, it will call
      tf.global_variables_initializer().
    train_params: If specified, use these training params instead of those in
      the `model`.
    save_only: This checkpointer is only intended for saving checkpoints.
  """
  self._train_dir = train_dir
  self._save_only = save_only
  self._init_op = init_op if init_op else tf.global_variables_initializer()

  self._save_path = os.path.join(self._train_dir, 'ckpt')

  if not train_params:
    assert model
    self._train_params = model.params.train
    self._model = model
  else:
    self._train_params = train_params
    self._model = None

  if not self._save_only:
    self._params = model.params
    self._model_tasks = model.tasks
    self._model = model
  else:
    self._params = None

  self._next_checkpoint_seconds = 0
  self._save_interval_seconds = self._train_params.save_interval_seconds
  self._saver = self._GetSaver()

  self._uninitialized_vars = tf.report_uninitialized_variables(
      tf.global_variables())

  # TODO(b/160786085): Move this logic into Overriding vars logic itself,
  # which requires refactoring things out of py_utils to avoid circular deps.
  self._restore_fns = []

  def _RegisterRestoreFn(tp):
    # Builds the override graph nodes for `tp`'s rules, if any, and keeps
    # the returned callable.
    if not tp.init_from_checkpoint_rules:
      return
    resolved = {
        GetSpecificCheckpoint(ckpt_path): var_rules
        for ckpt_path, var_rules in tp.init_from_checkpoint_rules.items()
    }
    tf.logging.info('OverrideVarsFromCheckpoints %s', resolved)
    self._restore_fns.append(
        py_utils.OverrideVarsFromCheckpoints(tf.global_variables(), resolved))

  # Add graph nodes to restore specific variables based on
  # init_from_checkpoint_rules.
  # TODO(b/159267006): Move this back to Restore().
  if self._model:
    for task in self._model.tasks:
      _RegisterRestoreFn(task.params.train)
  if self._params:
    _RegisterRestoreFn(self._params.train)
def RestoreIfNeeded(self, sess):
  """If vars are not initialized, restore from checkpoint.

  Steps, in order:
    1. Return immediately if every variable is already initialized.
    2. Assert that no variable is initialized (partial state is unexpected).
    3. Try `self._Restore(sess)`; on success, assert nothing is left
       uninitialized and return.
    4. Otherwise apply any `init_from_checkpoint_rules` from the tasks and
       from the model's train params.
    5. Run an initializer over whatever is still uninitialized.

  Args:
    sess: Session used to query, restore and initialize variables.
  """
  assert not self._save_only
  uninitialized_var_names = self.GetUninitializedVars(sess)
  # uninitialized_var_names is a list of strings without ":0" suffix.
  # tf.report_uninitialized_variables returns binary strings.
  assert all(isinstance(s, six.binary_type) for s in uninitialized_var_names)
  if not uninitialized_var_names:
    # All variables are already initialized.
    return

  tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)

  # There should only be uninitialized variables if all variables are
  # uninitialized.
  all_var_names = [
      six.ensure_binary(v.name[:-2]) for v in tf.global_variables()
  ]
  assert (set(uninitialized_var_names) == set(all_var_names)), sorted(
      set(all_var_names) - set(uninitialized_var_names))

  if self._Restore(sess):
    # Successfully restored from checkpoint.
    uninitialized_var_names = self.GetUninitializedVars(sess)
    assert not uninitialized_var_names, uninitialized_var_names
    return

  if (self._params.train.init_from_checkpoint_rules or
      any(task.params.train.init_from_checkpoint_rules
          for task in self._model_tasks)):
    for task in self._model.tasks:
      tp = task.params.train
      if tp.init_from_checkpoint_rules:
        tf.logging.info('OverrideVarsFromCheckpoints %s',
                        tp.init_from_checkpoint_rules)
        py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(),
                                             tp.init_from_checkpoint_rules)

    if self._params.train.init_from_checkpoint_rules:
      tp = self._params.train
      tf.logging.info('OverrideVarsFromCheckpoints %s',
                      tp.init_from_checkpoint_rules)
      py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(),
                                           tp.init_from_checkpoint_rules)

    # The rules may have covered everything; re-query before initializing.
    uninitialized_var_names = self.GetUninitializedVars(sess)
    if not uninitialized_var_names:
      return

    tf.logging.info('Remaining uninitialized vars: %s',
                    uninitialized_var_names)

  # Need to retrieve vars, removing ":0" suffix from names.
  # Use a set so the lookup below is O(1) per variable, not O(n).
  pending_names = set(uninitialized_var_names)
  uninitialized_vars = [
      v for v in tf.global_variables()
      if six.ensure_binary(v.name[:-2]) in pending_names
  ]
  tf.logging.info('Initialize variables: %s',
                  sorted([v.name[:-2] for v in uninitialized_vars]))
  sess.run(tf.variables_initializer(uninitialized_vars))
def Export(cls,
           model_cfg,
           model_task_name=None,
           device_options=InferenceDeviceOptions(
               device='',
               retain_device_placement=False,
               var_options=None,
               gen_init_op=True,
               dtype_override=None,
               fprop_dtype_override=None),
           freeze_checkpoint=None,
           freeze_defaults=False,
           export_path=None,
           subgraph_filter=None,
           random_seed=None,
           disable_packed_input=True,
           prune_graph=True,
           export_graph_collections=False):
  """Exports a InferenceGraph proto with piecewise subgraphs.

  Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

  Note: Enable FLAGS.pin_vars_to_cpu (default false) to make weight-sharing
  and multi-core inference on TPUs work properly.

  Args:
    model_cfg: a Params instance as returned by
      model_registry.GetParams(modelname, 'Test') or model_params.Model().
    model_task_name: The task to generate an inference graph for. Should be
      None for single-task models.
    device_options: Device options for the accelerator used for serving.
    freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
      given.
    freeze_defaults: Default initializes the graph and freeze. Useful for
      early testing of downstream tools without having a checkpoint.
    export_path: If not None, write the inference graph in ASCII to this path.
    subgraph_filter: A string or a list of subgraph names. If not None or
      empty, export only this list of inference subgraphs.
    random_seed: Fixes the random seed in the exported inference graph.
    disable_packed_input: Disable packed input for inference writing purposes.
    prune_graph: If true, prune the graph to just the parts we need.
    export_graph_collections: If true, export graph collections to the
      InferenceGraph proto.

  Returns:
    InferenceGraph proto.

  Raises:
    ValueError: if the model does not support the listed subgraphs; if called
      in eager mode; if both `device_options.dtype_override` and
      `device_options.fprop_dtype_override` are set; if `subgraph_filter`
      filters out every subgraph; or if freezing is requested and
      `cls._DeviceSupportsFreezing(device_options)` returns True.
  """
  if py_utils.IsEagerMode():
    raise ValueError('InferenceGraph exporter does not work in Eager mode.')
  assert issubclass(model_cfg.cls, base_model.BaseModel)
  if device_options.dtype_override and device_options.fprop_dtype_override:
    raise ValueError(
        'device_options.dtype_override and device_options.fprop_dtype_override '
        'can not both be set.')
  # Accept a single subgraph name as well as a list/tuple of names.
  if subgraph_filter and not isinstance(subgraph_filter, (tuple, list)):
    subgraph_filter = [subgraph_filter]

  # Disable assertions unless user explicitly enables it.
  if FLAGS['enable_asserts'].using_default_value:
    FLAGS.enable_asserts = False

  # TODO(laurenzo): Work out how much we need to specify here in terms of
  # cluster configuration.
  cls._SetClusterParams(model_cfg.cluster, device_options)

  # Configure the model.
  model_cfg.random_seed = random_seed
  model_cfg.is_inference = True

  if disable_packed_input:

    def _DisablePackedInput(task):
      # Turns off packed_input on the task's encoder/decoder when they exist
      # and define that param.
      if (_ParamExists(task, 'encoder') and
          _ParamExists(task.encoder, 'packed_input')):
        task.encoder.packed_input = False
      if (_ParamExists(task, 'decoder') and
          _ParamExists(task.decoder, 'packed_input')):
        task.decoder.packed_input = False

    if issubclass(model_cfg.cls, base_model.MultiTaskModel):
      for _, task_param in model_cfg.task_params.IterParams():
        _DisablePackedInput(task_param)
    else:
      _DisablePackedInput(model_cfg.task)

  tf.logging.debug('Model %s params:', model_cfg.name)
  for line in model_cfg.ToText().split('\n'):
    tf.logging.debug('%s', line)

  # Instantiate the graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.random.set_seed(random_seed)
    cluster = model_cfg.cluster.Instantiate()
    device = cluster.GetPlacer()
    tpu_const_scope = _DummyScope()
    if (IsTpu(device_options) and
        device_options.var_options == 'AS_CONSTANTS'):
      # Do not specify devices for variables if we are marking them as
      # constants.
      device = ''
      tpu_const_scope = ConstGuaranteeScope()

    with cluster, tf.device(device), tpu_const_scope:

      bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
          device_options)

      if bfloat16_override:
        py_utils.UpdateDtype(model_cfg, tf.bfloat16)
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      act_bfloat16_override = ShouldForceBfloat16ForActivations(
          device_options)
      if act_bfloat16_override:
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      # Hard-code TPU-related flags prior to instantiating model.
      old_enable_asserts = FLAGS.enable_asserts
      old_xla_device = FLAGS.xla_device
      if IsTpu(device_options):
        FLAGS.enable_asserts = False
        FLAGS.xla_device = 'tpu'

      try:
        mdl = model_cfg.Instantiate()
        task = mdl.GetTask(model_task_name)

        # With EMA enabled, restore through the EMA's variable mapping.
        variables_to_restore = (
            _MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
            mdl.ema.variables_to_restore(mdl.variables_for_ema))

        if bfloat16_override:
          saver_var_spec = (
              bfloat16_variables
              .get_saver_spec_for_variables_with_bf16_overrides(
                  variables_to_restore))
          # For TPU embedding layers, if the table explicitly specifies the
          # inference dtype as bfloat16, the variables in the checkpoint must
          # already be in bfloat16, so we change back to bfloat16 to avoid
          # dtype mismatch.
          for var_name in (tpu_embedding_layers.TpuEmbeddingCollection.Get()
                           .inference_with_bfloat16_var_names):
            saver_var_spec[var_name] = variables_to_restore[var_name]
        else:
          saver_var_spec = variables_to_restore

        saver = tf.train.Saver(saver_var_spec)
        # Created for its named op only. 'init_all_variables' is kept as a
        # pruning root further down.
        tf.variables_initializer(
            tf.global_variables(), name='init_all_variables')
        if IsTpu(device_options) and device_options.gen_init_op:
          tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

        if freeze_checkpoint or freeze_defaults:
          # Replace variables with tensors using tf.identity in theta before
          # freezing to avoid the graph referencing types of DT_RESOURCE.
          def AddIdentityToTheta(layer):
            # pylint: disable=protected-access
            layer._private_theta = py_utils.Transform(tf.identity,
                                                      layer._private_theta)
            # pylint: enable=protected-access
            # Recurse into every child layer.
            layer.children.Transform(AddIdentityToTheta)

          AddIdentityToTheta(task)

        inference_graph_proto = inference_graph_pb2.InferenceGraph()
        subgraphs_proto = task.Inference()
        if isinstance(subgraphs_proto, dict):
          subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
        for name, subgraph in subgraphs_proto.subgraphs.items():
          if not subgraph_filter or name in subgraph_filter:
            inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

        if not inference_graph_proto.subgraphs and subgraph_filter:
          raise ValueError(
              f'Subgraph filters {subgraph_filter} filtered out all '
              'subgraphs. Defined subgraphs: '
              f'{list(subgraphs_proto.subgraphs.keys())}')

        # Yes, graph collections are bad, however this seems to be the
        # easiest way to get this assets registered from
        # TextFileInitializer.
        assets_collection = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
        for asset in assets_collection:
          if asset.op.type == 'Const' and asset.op.get_attr(
              'dtype') == tf.dtypes.string:
            constant_value = asset.op.get_attr('value')
            if constant_value.string_val:
              tf.logging.info('Found asset file_path: %s',
                              constant_value.string_val[0])
              asset_file_def = inference_graph_proto.asset_file_def.add()
              asset_file_def.tensor_info.name = asset.name
              asset_file_def.filename = constant_value.string_val[0]

        # Add a table init op and global variable init op to the graph.
        # Tables can be declared anywhere in the graph, so this op has to be
        # added last.
        tf.tables_initializer(name='init_all_tables')
      finally:
        # Reset TPU-related flags after model instantiation.
        FLAGS.enable_asserts = old_enable_asserts
        FLAGS.xla_device = old_xla_device

  tf.logging.info('Graph contains ops: %r',
                  [op.name for op in graph.get_operations()])

  # Collection defs
  if not tf.executing_eagerly():
    if export_graph_collections:
      meta_graph = tf.train.export_meta_graph(graph=graph)
      for key in meta_graph.collection_def:
        tf.logging.info('copying collection %s', key)
        inference_graph_proto.collection_def[key].CopyFrom(
            meta_graph.collection_def[key])
  else:
    tf.logging.warning('Not exporting collection defs '
                       'since operating in eager mode.')

  # Freezing.
  if freeze_defaults or freeze_checkpoint:
    output_op_names = GetOutputOpNames(
        graph,
        inference_graph_proto,
        preserve_colocation_nodes=False,
        preserve_saver_restore_nodes=False)
    # NOTE(review): this raises when _DeviceSupportsFreezing(...) is True,
    # which reads inverted relative to the helper's name. Confirm against its
    # definition.
    if cls._DeviceSupportsFreezing(device_options):
      raise ValueError('freeze_checkpoint cannot be used with device ' +
                       device_options.device)
    if freeze_checkpoint:
      tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
      graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                             output_op_names)
    elif freeze_defaults:
      tf.logging.info('Default initializing graph and freezing.')
      graph_def = _FreezeDefaults(graph, output_op_names)
  else:
    # Unfrozen export: ship the saver so variables can be restored at load
    # time.
    inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())
    graph_def = graph.as_graph_def()

    if prune_graph:
      output_op_names = GetOutputOpNames(graph, inference_graph_proto)

      # Prune the graph to just the parts we need.
      # To support restoring, we have to not prune out the restore node.
      # The 'save/...' names are presumably the nodes created by the
      # tf.train.Saver above.
      output_op_names.append('init_all_tables')
      output_op_names.append('init_all_variables')
      output_op_names.append('save/control_dependency')
      output_op_names.append('save/restore_all')
      if IsTpu(device_options) and device_options.gen_init_op:
        output_op_names.append('tpu_init_op')

      tf.logging.info('Pruning graph to output ops: %r', output_op_names)
      graph_def = tf.compat.v1.graph_util.extract_sub_graph(
          graph_def, output_op_names)

  if not device_options.retain_device_placement:
    # Clear the device so that the runtime can choose.
    tf.logging.info('Clearing device placement for: %s',
                    device_options.device)
    for node in graph_def.node:
      node.ClearField('device')
    for function in graph_def.library.function:
      for node_def in function.node_def:
        node_def.ClearField('device')

  inference_graph_proto.graph_def.CopyFrom(graph_def)

  if export_path:
    with tf.io.gfile.GFile(export_path, 'w') as f:
      f.write(text_format.MessageToString(inference_graph_proto))
  return inference_graph_proto