Esempio n. 1
0
        self.MaybeConfigRunDistributed()
        self.MaybeConfigCloudTpu()
        self.MaybeLaunchTensorFlow()

        if FLAGS.job.startswith('evaler_once_'):
            # E.g., trainer --model=foo.bar.Model --logdir=...
            # --run_locally=cpu --mode=sync --job=evaler_once_test@65200
            self.RunEvalerOnce()
            return

        self.StartRunners(
            self.CreateRunners(FLAGS.job.split(','), FLAGS.logdir))


def main(unused_argv):
    """Builds a `RunnerManager` for the model named by `FLAGS.model` and starts it.

    Args:
      unused_argv: Positional command-line arguments; not used.
    """
    manager = RunnerManager(FLAGS.model)
    manager.Start()


if __name__ == '__main__':
    # The --model flag has to be given on the command line.
    tf.flags.mark_flag_as_required('model')
    # Parse argv right away (tolerating flags that are not known) so that the
    # flag values read below are available before tf.app.run() is reached.
    FLAGS(sys.argv, known_only=True)
    # Global TensorFlow execution-mode switches, each driven by its own flag.
    if FLAGS.disable_tf2:
        tf.disable_v2_behavior()
    py_utils.SetEagerMode(FLAGS.use_eager)
    tf.config.run_functions_eagerly(FLAGS.run_functions_eagerly)
    if FLAGS.enable_tf_data_debug_mode:
        tf.data.experimental.enable_debug_mode()
    # Import the params for the requested model before the app is started.
    model_imports.ImportParams(FLAGS.model)
    # Put the flags back into the unparsed state before handing control to
    # tf.app.run(); presumably so it can parse the full command line itself
    # (including any flags defined by the import above) -- TODO confirm.
    FLAGS.unparse_flags()
    tf.app.run(main)
Esempio n. 2
0
def GetProgramSchedule(class_key):
  """Imports the params for `class_key`, then fetches its program schedule.

  Args:
    class_key: Key handed to both `model_imports.ImportParams` and the
      registry lookup.

  Returns:
    The result of `_ModelRegistryHelper.GetProgramSchedule(class_key)`.
  """
  # Import first, then query the registry for the same key.
  model_imports.ImportParams(class_key)
  schedule = _ModelRegistryHelper.GetProgramSchedule(class_key)
  return schedule
Esempio n. 3
0
def GetParams(class_key, dataset_name):
  """Imports the params for `class_key`, then fetches them for one dataset.

  Args:
    class_key: Key handed to both `model_imports.ImportParams` and the
      registry lookup.
    dataset_name: Dataset name forwarded to the registry lookup.

  Returns:
    The result of `_ModelRegistryHelper.GetParams(class_key, dataset_name)`.
  """
  # Import first, then query the registry for the same key.
  model_imports.ImportParams(class_key)
  params = _ModelRegistryHelper.GetParams(class_key, dataset_name)
  return params
Esempio n. 4
0
def GetClass(class_key):
  """Imports the params for `class_key`, then fetches the registered class.

  Args:
    class_key: Key handed to both `model_imports.ImportParams` and the
      registry lookup.

  Returns:
    The result of `_ModelRegistryHelper.GetClass(class_key)`.
  """
  # Import first, then query the registry for the same key.
  model_imports.ImportParams(class_key)
  registered_cls = _ModelRegistryHelper.GetClass(class_key)
  return registered_cls
Esempio n. 5
0
    def _load_model(self):
        """
        Define and instantiate the computation graph.

        Downloads the patched decoder, the vocabulary and the pretrained checkpoint files listed in
        `self._LINGVO_CFG`, patches Lingvo accordingly, instantiates the `asr.librispeech.Librispeech960Wpm`
        model with its "Test" params and restores the pretrained weights into `self.sess`.

        :return: A tuple `(model, task, cluster)` holding the instantiated Lingvo model, the task obtained from
                 `model.GetTask()` and the cluster the model was built under.
        """
        # Function-scope imports; presumably so TensorFlow v1 compat / Lingvo are only required when this
        # method is actually called -- TODO confirm.
        import tensorflow.compat.v1 as tf1
        from lingvo import model_registry, model_imports
        from lingvo.core import cluster_factory

        from asr.librispeech import Librispeech960Wpm

        # check and download patched Lingvo ASR decoder; the returned path is not needed, the file only has to
        # exist before `asr.decoder_patched` is imported below
        _ = self._check_and_download_file(
            self._LINGVO_CFG["decoder"]["uri"], self._LINGVO_CFG["decoder"]["basename"], self._LINGVO_CFG["path"], "asr"
        )

        # monkey-patch the lingvo.asr.decoder.AsrDecoderBase._ComputeMetrics method with patched method according
        # to Qin et al
        from lingvo.tasks.asr import decoder
        from asr import decoder_patched

        decoder.AsrDecoderBase._ComputeMetrics = decoder_patched.AsrDecoderBase._ComputeMetrics  # pylint: disable=W0212

        # check and download Lingvo ASR vocab
        vocab_path = self._check_and_download_file(
            self._LINGVO_CFG["vocab"]["uri"], self._LINGVO_CFG["vocab"]["basename"], self._LINGVO_CFG["path"], "asr"
        )

        # monkey-patch the asr.librispeech.Librispeech960Wpm class attribute WPM_SYMBOL_TABLE_FILEPATH so that it
        # points at the downloaded vocab file
        Librispeech960Wpm.WPM_SYMBOL_TABLE_FILEPATH = vocab_path

        # register model params and fetch the "Test" dataset params of the model
        model_name = "asr.librispeech.Librispeech960Wpm"
        model_imports.ImportParams(model_name)
        params = model_registry._ModelRegistryHelper.GetParams(model_name, "Test")  # pylint: disable=W0212

        # set random seed parameter (left at the params' own value when no seed was given)
        if self.random_seed is not None:
            params.random_seed = self.random_seed

        # instantiate Lingvo ASR model under the cluster built from the params and on the cluster's placer device
        cluster = cluster_factory.Cluster(params.cluster)
        with cluster, tf1.device(cluster.GetPlacer()):
            model = params.Instantiate()
            task = model.GetTask()

        # load Qin et al. pretrained model; the data file only has to be downloaded, its path is not used directly
        _ = self._check_and_download_file(
            self._LINGVO_CFG["model_data"]["uri"],
            self._LINGVO_CFG["model_data"]["basename"],
            self._LINGVO_CFG["path"],
            "asr",
            "model",
        )
        model_index_path = self._check_and_download_file(
            self._LINGVO_CFG["model_index"]["uri"],
            self._LINGVO_CFG["model_index"]["basename"],
            self._LINGVO_CFG["path"],
            "asr",
            "model",
        )
        # initialize every global variable first, then overwrite only those whose name starts with "librispeech"
        # from the checkpoint; the restore path is the index file path with its extension stripped
        self.sess.run(tf1.global_variables_initializer())
        saver = tf1.train.Saver([var for var in tf1.global_variables() if var.name.startswith("librispeech")])
        saver.restore(self.sess, os.path.splitext(model_index_path)[0])

        # set 'enable_asserts'-flag to False (Note: this flag ensures correct GPU support)
        tf1.flags.FLAGS.enable_asserts = False

        return model, task, cluster
Esempio n. 6
0
def import_params(model_name: str, require_success: bool = True) -> Any:
    """Try to import only the files that could define `model_name`.

    Args:
        model_name: Name of the model whose params should be imported.
        require_success: Forwarded unchanged to
            `lingvo_model_imports.ImportParams`.

    Returns:
        Whatever `lingvo_model_imports.ImportParams` returns for
        `model_name` under `_TASK_ROOT`.
    """
    imported = lingvo_model_imports.ImportParams(
        model_name,
        _TASK_ROOT,
        require_success=require_success,
    )
    return imported