    # NOTE(review): the statements below are the tail of an enclosing method
    # (its `def` is outside this chunk — presumably RunnerManager.Start; confirm).
    self.MaybeConfigRunDistributed()
    self.MaybeConfigCloudTpu()
    self.MaybeLaunchTensorFlow()
    if FLAGS.job.startswith('evaler_once_'):
      # One-shot evaluation mode: run the evaler a single time and exit
      # instead of starting long-running runner threads.
      # E.g., trainer --model=foo.bar.Model --logdir=...
      # --run_locally=cpu --mode=sync --job=evaler_once_test@65200
      self.RunEvalerOnce()
      return
    # Normal path: one runner per comma-separated job name.
    self.StartRunners(
        self.CreateRunners(FLAGS.job.split(','), FLAGS.logdir))


def main(unused_argv):
  """Entry point for tf.app.run: builds the manager for FLAGS.model and starts it."""
  RunnerManager(FLAGS.model).Start()


if __name__ == '__main__':
  tf.flags.mark_flag_as_required('model')
  # Pre-parse argv so the flags below are readable before tf.app.run
  # re-parses; unknown flags are tolerated at this stage.
  FLAGS(sys.argv, known_only=True)
  if FLAGS.disable_tf2:
    tf.disable_v2_behavior()
  py_utils.SetEagerMode(FLAGS.use_eager)
  tf.config.run_functions_eagerly(FLAGS.run_functions_eagerly)
  if FLAGS.enable_tf_data_debug_mode:
    tf.data.experimental.enable_debug_mode()
  # Import the model's param files before flag re-parsing, since imported
  # modules may define additional flags.
  model_imports.ImportParams(FLAGS.model)
  # Reset parse state so tf.app.run can parse argv cleanly (including any
  # flags registered by the imports above).
  FLAGS.unparse_flags()
  tf.app.run(main)
def GetProgramSchedule(class_key):
  """Returns the program schedule registered for `class_key`.

  Imports the param module(s) that may define the model first, so callers
  do not have to import them explicitly before the registry lookup.
  """
  # Ensure the defining module has been loaded; the registry is populated
  # as a side effect of importing model params.
  model_imports.ImportParams(class_key)
  schedule = _ModelRegistryHelper.GetProgramSchedule(class_key)
  return schedule
def GetParams(class_key, dataset_name):
  """Returns the model params registered for `class_key` and `dataset_name`.

  Imports the param module(s) that may define the model first, so callers
  do not have to import them explicitly before the registry lookup.
  """
  # Registry entries are created as an import side effect.
  model_imports.ImportParams(class_key)
  params = _ModelRegistryHelper.GetParams(class_key, dataset_name)
  return params
def GetClass(class_key):
  """Returns the model class registered for `class_key`.

  Imports the param module(s) that may define the model first, so callers
  do not have to import them explicitly before the registry lookup.
  """
  # Registry entries are created as an import side effect.
  model_imports.ImportParams(class_key)
  cls = _ModelRegistryHelper.GetClass(class_key)
  return cls
def _load_model(self):
    """Define and instantiate the Lingvo ASR computation graph.

    Downloads (if needed) a patched decoder, the WPM vocab and the
    pretrained Qin et al. checkpoint, monkey-patches Lingvo's decoder
    metrics and vocab path, builds the Librispeech960Wpm model and
    restores its weights into `self.sess`.

    Returns:
        A `(model, task, cluster)` tuple of Lingvo objects.
    """
    # Deferred imports: lingvo must be imported after the decoder patch
    # below is downloaded — presumably to keep import-time light as well;
    # TODO confirm.
    import tensorflow.compat.v1 as tf1
    from lingvo import model_registry, model_imports
    from lingvo.core import cluster_factory
    from asr.librispeech import Librispeech960Wpm

    # check and download patched Lingvo ASR decoder (return value unused;
    # only the side effect of the file existing on disk matters here)
    _ = self._check_and_download_file(
        self._LINGVO_CFG["decoder"]["uri"], self._LINGVO_CFG["decoder"]["basename"], self._LINGVO_CFG["path"], "asr"
    )

    # monkey-patch lingvo.tasks.asr.decoder.AsrDecoderBase._ComputeMetrics
    # with the patched method according to Qin et al.
    from lingvo.tasks.asr import decoder
    from asr import decoder_patched

    decoder.AsrDecoderBase._ComputeMetrics = decoder_patched.AsrDecoderBase._ComputeMetrics  # pylint: disable=W0212

    # check and download the Lingvo ASR vocab
    # vocab_path = self._check_and_download_vocab()
    vocab_path = self._check_and_download_file(
        self._LINGVO_CFG["vocab"]["uri"], self._LINGVO_CFG["vocab"]["basename"], self._LINGVO_CFG["path"], "asr"
    )

    # monkey-patch tasks.asr.librispeech.Librispeech960Wpm class attribute
    # WPM_SYMBOL_TABLE_FILEPATH so the model reads the downloaded vocab
    Librispeech960Wpm.WPM_SYMBOL_TABLE_FILEPATH = vocab_path

    # register model params (registry is populated as an import side effect)
    model_name = "asr.librispeech.Librispeech960Wpm"
    model_imports.ImportParams(model_name)
    params = model_registry._ModelRegistryHelper.GetParams(model_name, "Test")  # pylint: disable=W0212

    # set random seed parameter, if the wrapper was configured with one
    if self.random_seed is not None:
        params.random_seed = self.random_seed

    # instantiate the Lingvo ASR model under the cluster's device placer
    cluster = cluster_factory.Cluster(params.cluster)
    with cluster, tf1.device(cluster.GetPlacer()):
        model = params.Instantiate()
        task = model.GetTask()

    # load the Qin et al. pretrained model: checkpoint data shard first
    # (unused return, download side effect only) ...
    _ = self._check_and_download_file(
        self._LINGVO_CFG["model_data"]["uri"],
        self._LINGVO_CFG["model_data"]["basename"],
        self._LINGVO_CFG["path"],
        "asr",
        "model",
    )
    # ... then the checkpoint index, whose path (minus extension) is the
    # checkpoint prefix Saver.restore expects.
    model_index_path = self._check_and_download_file(
        self._LINGVO_CFG["model_index"]["uri"],
        self._LINGVO_CFG["model_index"]["basename"],
        self._LINGVO_CFG["path"],
        "asr",
        "model",
    )
    # initialize everything first, then overwrite only the model's own
    # variables (name-prefixed "librispeech") from the checkpoint
    self.sess.run(tf1.global_variables_initializer())
    saver = tf1.train.Saver([var for var in tf1.global_variables() if var.name.startswith("librispeech")])
    saver.restore(self.sess, os.path.splitext(model_index_path)[0])

    # set 'enable_asserts'-flag to False (Note: this flag ensures correct GPU support)
    tf1.flags.FLAGS.enable_asserts = False
    return model, task, cluster
def import_params(model_name: str, require_success: bool = True) -> Any:
    """Import only the modules under the task root that may define the model.

    Delegates to Lingvo's ``ImportParams`` with the task root pinned to
    ``_TASK_ROOT``.

    Args:
        model_name: Registry key of the model whose param files to import.
        require_success: Forwarded to ``ImportParams``; presumably controls
            whether a failed import raises — confirm against Lingvo docs.

    Returns:
        Whatever ``ImportParams`` returns.
    """
    result = lingvo_model_imports.ImportParams(
        model_name, _TASK_ROOT, require_success=require_success
    )
    return result