예제 #1
0
from absl import app, logging

import neurst.utils.flags_core as flags_core
from neurst.data.data_pipelines.data_pipeline import lowercase_and_remove_punctuations
from neurst.data.text import Tokenizer, build_tokenizer

# Command-line flags of this script; `_main` hands this list to
# `flags_core.define_flags` and `flags_core.intelligent_parse_flags`.
FLAG_LIST = [
    # Source and destination text files.
    flags_core.Flag(
        "input",
        dtype=flags_core.Flag.TYPE.STRING,
        default=None,
        help="The path to the input text file.",
    ),
    flags_core.Flag(
        "output",
        dtype=flags_core.Flag.TYPE.STRING,
        default=None,
        help="The path to the output text file.",
    ),
    # Text normalisation switches, forwarded to
    # `lowercase_and_remove_punctuations` before tokenization.
    flags_core.Flag(
        "lowercase",
        dtype=flags_core.Flag.TYPE.BOOLEAN,
        default=None,
        help="Whether to lowercase.",
    ),
    flags_core.Flag(
        "remove_punctuation",
        dtype=flags_core.Flag.TYPE.BOOLEAN,
        default=None,
        help="Whether to remove the punctuations.",
    ),
    # Module flag keyed by the tokenizer registry name.
    flags_core.ModuleFlag(
        Tokenizer.REGISTRY_NAME,
        help="The tokenizer.",
    ),
]


def _main(_):
    """Tokenize the input text file line by line into the output file.

    Flags are defined and parsed from `FLAG_LIST`, a tokenizer is built from
    the parsed arguments, and every input line is stripped, passed through
    `lowercase_and_remove_punctuations` and then tokenized. One tokenized
    line is written per input line.

    Args:
        _: Unused positional argument (presumably the argv list supplied by
            the absl application runner -- confirm against the call site).
    """
    parser = flags_core.define_flags(FLAG_LIST, with_config_file=False)
    args, unparsed_argv = flags_core.intelligent_parse_flags(FLAG_LIST, parser)
    flags_core.verbose_flags(FLAG_LIST, args, unparsed_argv)

    tokenizer = build_tokenizer(args)
    # The input is opened before the output, both through TensorFlow's GFile.
    with tf.io.gfile.GFile(args["input"]) as reader, \
            tf.io.gfile.GFile(args["output"], "w") as writer:
        for raw_line in reader:
            normalized = lowercase_and_remove_punctuations(
                tokenizer.language,
                raw_line.strip(),
                args["lowercase"],
                args["remove_punctuation"])
            tokenized = tokenizer.tokenize(normalized, return_str=True)
            writer.write(tokenized + "\n")
예제 #2
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from absl import app, logging

import neurst.utils.flags_core as flags_core
from neurst.metrics import Metric, build_metric
from neurst.utils.misc import flatten_string_list

# Command-line flags for the evaluation script.
FLAG_LIST = [
    # A single hypothesis file.
    flags_core.Flag(
        "hypo_file",
        dtype=flags_core.Flag.TYPE.STRING,
        default=None,
        help="The path to hypothesis file.",
    ),
    # Declared with `multiple=True`, so more than one reference file may be
    # supplied.
    flags_core.Flag(
        "ref_file",
        dtype=flags_core.Flag.TYPE.STRING,
        default=None,
        multiple=True,
        help="The path to reference file. ",
    ),
    # Module flag keyed by the metric registry name.
    flags_core.ModuleFlag(
        Metric.REGISTRY_NAME,
        help="The metric for evaluation.",
    ),
]


def evaluate(metric, hypo_file, ref_file):
    """Load the hypothesis lines and the reference lines for evaluation.

    Args:
        metric: The metric object; must not be None.
        hypo_file: Path to the hypothesis file. Every line is read and
            stripped of surrounding whitespace.
        ref_file: Reference file path(s). The value is passed through
            `flatten_string_list` and each resulting path is read the same
            way as the hypothesis file.

    Raises:
        AssertionError: If `metric` is None, or `hypo_file` / `ref_file` is
            empty or None.

    NOTE(review): the body shown here only builds `hypo` and `ref_list`;
    neither is passed to `metric` or returned, so the scoring step looks
    missing or truncated -- confirm against the complete source.
    """
    assert metric is not None
    assert hypo_file
    assert ref_file
    # One stripped string per hypothesis line.
    with tf.io.gfile.GFile(hypo_file) as fp:
        hypo = [line.strip() for line in fp]

    # One list of stripped lines per reference file, in the order yielded by
    # `flatten_string_list`.
    ref_list = []
    for one_ref_file in flatten_string_list(ref_file):
        with tf.io.gfile.GFile(one_ref_file) as fp:
            ref = [line.strip() for line in fp]
            ref_list.append(ref)
예제 #3
0
        default="train.tfrecords-%5.5d-of-%5.5d",
        help=
        "The template name of output tfrecords, like train.tfrecords-%5.5d-of-%5.5d."
    ),
    flags_core.Flag(
        "mode",
        dtype=flags_core.Flag.TYPE.STRING,
        default=ModeKeys.TRAIN,
        choices=ModeKeys._fields,  # pylint: disable=protected-access
        help="The mode to acquire data preprocess method, "
        "that is, the result TF Record dataset will be used for."),
    # Boolean switch for progress reporting; the help text is user-facing.
    flags_core.Flag("progressbar",
                    dtype=flags_core.Flag.TYPE.BOOLEAN,
                    default=None,
                    help="Whether to display the progressbar"),
    flags_core.ModuleFlag(Task.REGISTRY_NAME,
                          help="The binding task for data pre-processing."),
    flags_core.ModuleFlag(Dataset.REGISTRY_NAME, help="The raw dataset."),
]


def _format_tf_feature(feature, dtype):
    """Wrap `feature` into a `tf.train.Feature` of the list type chosen by `dtype`.

    Args:
        feature: The value(s) to store. It is converted with `numpy.array`
            and flattened before being stored.
        dtype: The Python type `int`, `float` or `str`. The comparison is by
            identity, so any other object (including `str`) selects the
            bytes list.

    Returns:
        A `tf.train.Feature` holding an `Int64List` for `int`, a `FloatList`
        for `float`, and a `BytesList` otherwise.
    """
    if dtype is str:
        # Text is UTF-8 encoded element-wise before it goes into a bytes list.
        feature = tf.nest.map_structure(lambda text: text.encode("utf-8"), feature)
    flat_values = numpy.array(feature).flatten()

    # Pick the typed list first, then build the Feature once.
    if dtype is int:
        feature_kwargs = {"int64_list": tf.train.Int64List(value=flat_values)}
    elif dtype is float:
        feature_kwargs = {"float_list": tf.train.FloatList(value=flat_values)}
    else:
        feature_kwargs = {"bytes_list": tf.train.BytesList(value=flat_values)}
    return tf.train.Feature(**feature_kwargs)

예제 #4
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
from absl import app, logging

import neurst.utils.flags_core as flags_core
from neurst.data.audio import FeatureExtractor, build_feature_extractor
from neurst.data.audio.float_identity import FloatIdentity
from neurst.data.datasets import Dataset, build_dataset
from neurst.data.datasets.audio.audio_dataset import AudioTripleTFRecordDataset

# Command-line flags: both are module flags that default to the class names
# of the audio TFRecord dataset and the float-identity feature extractor.
FLAG_LIST = [
    flags_core.ModuleFlag(
        Dataset.REGISTRY_NAME,
        help="The audio TFRecord dataset.",
        default=AudioTripleTFRecordDataset.__name__,
    ),
    flags_core.ModuleFlag(
        FeatureExtractor.REGISTRY_NAME,
        help="The feature extractor already applied on the audio.",
        default=FloatIdentity.__name__,
    ),
]

# NOTE(review): presumably the cumulative fractions at which statistics are
# reported -- confirm against the code that reads this constant.
_DISPLAY_PERCENTS = [
    0.1, 0.3, 0.5, 0.8,
    0.9, 0.95, 0.99, 0.999,
]


class BigCounter(object):
    """Counter state holder configured by a `base` value.

    NOTE(review): only the first lines of `__init__` are visible here, so how
    `base` and the stored values are used cannot be confirmed from this view.
    """

    def __init__(self, base=1000):
        # `base` is stored as given (default 1000).
        self._base = base
        # Starts as an empty mapping; its key/value meaning is not visible here.
        self._values = dict()
        # Large initial value -- presumably a sentinel that the first observed
        # value replaces; TODO confirm.
        self.min_value = 1000000
예제 #5
0
    flags_core.Flag(
        "hparams_set",
        dtype=flags_core.Flag.TYPE.STRING,
        help="A string indicating a set of pre-defined hyper-parameters, "
        "e.g. transformer_base, transformer_big or transformer_768_16e_3d."),
    flags_core.Flag("model_dir",
                    dtype=flags_core.Flag.TYPE.STRING,
                    help="The path to the checkpoint for saving and loading."),
    flags_core.Flag("enable_quant",
                    dtype=flags_core.Flag.TYPE.BOOLEAN,
                    default=False,
                    help="Whether to enable quantization for finetuning."),
    flags_core.Flag("quant_params",
                    dtype=flags_core.Flag.TYPE.STRING,
                    help="A dict of parameters for quantization."),
    flags_core.ModuleFlag(BaseExperiment.REGISTRY_NAME, help="The program."),
    flags_core.ModuleFlag(Task.REGISTRY_NAME, help="The binding task."),
    flags_core.ModuleFlag(BaseModel.REGISTRY_NAME, help="The model."),
    flags_core.ModuleFlag(Dataset.REGISTRY_NAME, help="The dataset."),
]


def _pre_load_args(args):
    cfg_file_args = yaml_load_checking(
        load_from_config_path(
            flatten_string_list(
                getattr(args, flags_core.DEFAULT_CONFIG_FLAG.name))))
    model_dirs = flatten_string_list(args.model_dir
                                     or cfg_file_args.get("model_dir", None))
    hparams_set = args.hparams_set
    if hparams_set is None:
예제 #6
0
                 help="The path to the ASR model checkpoint."),
 flags_core.Flag("mt_model_dir",
                 dtype=flags_core.Flag.TYPE.STRING,
                 help="The path to the MT model checkpoint."),
 flags_core.Flag("asr_output_file",
                 dtype=flags_core.Flag.TYPE.STRING,
                 help="The path to save ASR hypothesis."),
 flags_core.Flag("mt_output_file",
                 dtype=flags_core.Flag.TYPE.STRING,
                 help="The path to save MT hypothesis."),
 flags_core.Flag("batch_size",
                 dtype=flags_core.Flag.TYPE.INTEGER,
                 default=32,
                 help="The batch size for inference."),
 flags_core.ModuleFlag(Dataset.REGISTRY_NAME,
                       default=AudioTripleTFRecordDataset.__name__,
                       help="The audio dataset."),
 flags_core.ModuleFlag("asr_" + SequenceSearch.REGISTRY_NAME,
                       module_name=SequenceSearch.REGISTRY_NAME,
                       default="beam_search",
                       help="The search method for ASR."),
 flags_core.ModuleFlag("mt_" + SequenceSearch.REGISTRY_NAME,
                       module_name=SequenceSearch.REGISTRY_NAME,
                       default="beam_search",
                       help="The search method for MT."),
 flags_core.ModuleFlag("asr_" + Metric.REGISTRY_NAME,
                       module_name=Metric.REGISTRY_NAME,
                       default="wer",
                       help="The metric to evaluate ASR output."),
 flags_core.ModuleFlag("mt_" + Metric.REGISTRY_NAME,
                       module_name=Metric.REGISTRY_NAME,
예제 #7
0
from absl import app, logging

import neurst.utils.flags_core as flags_core
from neurst.data.datasets import Dataset, build_dataset
from neurst.data.datasets.audio.audio_dataset import RawAudioDataset

# Command-line flags for extracting text from a raw audio dataset.
FLAG_LIST = [
    # Required: where the transcripts are written.
    flags_core.Flag("output_transcript_file",
                    dtype=flags_core.Flag.TYPE.STRING,
                    required=True,
                    help="The path to save transcriptions."),
    # Optional: where the translations are written. The help text names
    # translations so that it is distinguishable from the transcript flag.
    flags_core.Flag("output_translation_file",
                    dtype=flags_core.Flag.TYPE.STRING,
                    default=None,
                    help="The path to save translations."),
    # Module flag keyed by the dataset registry name.
    flags_core.ModuleFlag(Dataset.REGISTRY_NAME, help="The raw dataset."),
]


def main(dataset, output_transcript_file, output_translation_file=None):
    """Write the dataset's transcripts, and optionally translations, to files.

    Each output file receives its entries joined by newlines and ends with a
    trailing newline.

    Args:
        dataset: A `RawAudioDataset` exposing `transcripts` and `translations`.
        output_transcript_file: Path that receives the transcripts.
        output_translation_file: Optional path that receives the translations.
            Nothing is written when this is empty or when the dataset has no
            translations.

    Raises:
        AssertionError: If `dataset` is not a `RawAudioDataset`, or if it
            yields no transcripts.
    """

    def _write_lines(path, lines):
        # One entry per line, with a newline after the last entry.
        with tf.io.gfile.GFile(path, "w") as writer:
            writer.write("\n".join(lines) + "\n")

    assert isinstance(dataset, RawAudioDataset)
    # Both attributes are read before the check, transcripts first.
    transcript_lines = dataset.transcripts
    translation_lines = dataset.translations
    assert transcript_lines, "Fail to extract transcripts."

    _write_lines(output_transcript_file, transcript_lines)
    if not translation_lines or not output_translation_file:
        return
    _write_lines(output_translation_file, translation_lines)