Example No. 1
 def class_or_method_args():
     this_args = super(SpeechToText, SpeechToText).class_or_method_args()
     this_args.extend([
         ModuleFlag("transcript_data_pipeline",
                    DataPipeline.REGISTRY_NAME,
                    default=TranscriptDataPipeline.__name__,
                    help="The target side transcript data pipeline."),
         Flag("audio_feature_dim",
              dtype=Flag.TYPE.INTEGER,
              default=80,
              help="The dimension of audio features."),
         Flag("audio_feature_channels",
              dtype=Flag.TYPE.INTEGER,
              default=1,
              help="The number of channels of audio features."),
         Flag("max_src_len",
              dtype=Flag.TYPE.INTEGER,
              default=None,
              help=
              "The maximum source length of training data (audio frames)."),
         Flag(
             "min_src_bucket_boundary",
             dtype=Flag.TYPE.INTEGER,
             default=128,
             help=
             "The minimum source length of the training bucket (audio frames)."
         ),
         Flag("max_trg_len",
              dtype=Flag.TYPE.INTEGER,
              default=None,
              help="The maximum target length of training data."),
         Flag("truncate_src",
              dtype=Flag.TYPE.BOOLEAN,
              default=None,
              help="Whether to truncate source to max_src_len."),
         Flag("truncate_trg",
              dtype=Flag.TYPE.BOOLEAN,
              default=None,
              help="Whether to truncate target to max_trg_len."),
         Flag(
             "experimental_frame_transcript_ratio",
             dtype=Flag.TYPE.INTEGER,
             default=None,
             help=
             "The ratio between the number of audio frames and the length of "
             "its transcript, used for bucketing training batches."
         ),
         Flag(
             "specaug",
             dtype=Flag.TYPE.STRING,
             default=None,
             help=
             "The arguments for spec augment, can be either predefined settings "
             "like LB, LD, SM, SS... or a dict containing detailed arguments."
         ),
         Flag("disable_batch_efficiency",
              dtype=Flag.TYPE.BOOLEAN,
              default=None,
              help="Whether to disable the batch efficiency.")
     ])
     return this_args
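All of these examples share one pattern: class_or_method_args() is a static method returning a list of Flag/ModuleFlag descriptors, and subclasses extend the parent's list via super(Cls, Cls).class_or_method_args(). As a minimal sketch of how such a descriptor list could be consumed, assuming a simplified stand-in Flag (the real NeurST Flag takes dtype=Flag.TYPE.* enums and more options, so this is illustrative only):
 import argparse
 from dataclasses import dataclass
 from typing import Any, Optional

 @dataclass
 class Flag:
     name: str
     dtype: type
     default: Any = None
     help: Optional[str] = None

 def register_flags(flags, parser):
     # Register each Flag descriptor as a command-line option.
     for f in flags:
         parser.add_argument("--" + f.name, type=f.dtype,
                             default=f.default, help=f.help)

 parser = argparse.ArgumentParser()
 register_flags([Flag("beam_size", int, default=4, help="The beam width.")],
                parser)
 args = parser.parse_args(["--beam_size", "8"])
 assert args.beam_size == 8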
Example No. 2
 def class_or_method_args():
     return [
         Flag("eval_steps", dtype=Flag.TYPE.INTEGER, default=1000,
              help="The steps between two validation steps."),
         Flag("eval_start_at", dtype=Flag.TYPE.INTEGER, default=0,
              help="The step to start validation process."),
     ]
Example No. 3
 def class_or_method_args():
     this_args = super(LanguageModel, LanguageModel).class_or_method_args()
     this_args.extend([
         # for creating data pipelines
         ModuleFlag(DataPipeline.REGISTRY_NAME, help="The data pipeline."),
         # for preprocessing data
         Flag("max_len",
              dtype=Flag.TYPE.INTEGER,
              default=None,
              help="The maximum length of training data."),
         Flag("truncate",
              dtype=Flag.TYPE.BOOLEAN,
              default=None,
              help="Whether to truncate data to max_len."),
         # for batching dataset
         Flag("batch_by_tokens",
              dtype=Flag.TYPE.BOOLEAN,
              default=None,
              help="Whether to batch the data by word tokens."),
         Flag(
             "begin_of_sentence",
             dtype=Flag.TYPE.STRING,
             default="bos",
             choices=["bos", "eos"],
             help=
             "The begin of sentence symbol for target side. The choice 'eos' "
             "is for compatibility with fairseq transformer."),
         Flag("gpu_efficient_level",
              dtype=Flag.TYPE.INTEGER,
              default=GPU_EFFICIENT_LEVEL.LEVEL0,
              choices=tuple(GPU_EFFICIENT_LEVEL),
              help="The efficient level for training using XLA, from 0~5."),
     ])
     return this_args
Example No. 4
 def class_or_method_args():
     this_args = super(CommonVoice, CommonVoice).class_or_method_args()
     this_args.extend([
         Flag("extraction", dtype=Flag.TYPE.STRING, default=None,
              choices=CommonVoice.EXTRACTION_CHOICES,
              help="The dataset portion to be extracted, i.e. train, dev, test, other, validated."),
         Flag("language", dtype=Flag.TYPE.STRING, default=None,
              help="the language portion to be extracted, e.g. en, zh-CN. Must be provided "
                   "if the input is a directory.")])
     return this_args
Example No. 5
 def class_or_method_args():
     return [
         Flag("multiple_datasets", dtype=Flag.TYPE.STRING,
              help="A dict of dataset class and parameters, "
                   "where the key is the dataset name and "
                   "the value is a dict of arguments for one dataset."),
         Flag("sample_weights", dtype=Flag.TYPE.FLOAT,
              help="A dict of weights for averaging metrics, where the key "
                   "is the dataset name. 1.0 for each by default.")
     ]
Example No. 6
 def class_or_method_args():
     this_args = super(SequenceGeneratorSavedmodel,
                       SequenceGeneratorSavedmodel).class_or_method_args()
     this_args.extend([
         Flag("export_path", dtype=Flag.TYPE.STRING, default=None,
              help="The path to the savedmodel."),
         Flag("version", dtype=Flag.TYPE.INTEGER, default=1,
              help="The version of the model."),
     ])
     return this_args
Example No. 7
 def class_or_method_args():
     return [
         Flag("path", dtype=Flag.TYPE.STRING,
              help="The path to multilingual datasets. "
                   "The record files should be stored in sub directories, which are named by src2trg, "
                   "e.g. rootpath/en2de, rootpath/en2fr..."),
         Flag("auto_switch_langs", dtype=Flag.TYPE.BOOLEAN,
              help="Whether to switch source and target langauges (which will doubled the dataset)."),
         ModuleFlag(DataSampler.REGISTRY_NAME, default=TemperatureSampler.__name__,
                    help="The sampler for unbalanced datasets.")
     ]
Example No. 8
 def class_or_method_args():
     return [
         Flag("schedule_steps",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="A list of triggered steps."),
         Flag("schedule_lrs",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="A list of learning rates."),
     ]
Example No. 9
 def class_or_method_args():
     this_args = super(Translation, Translation).class_or_method_args()
     this_args.extend([
         Flag("gpu_efficient_level", dtype=Flag.TYPE.INTEGER, default=GPU_EFFICIENT_LEVEL.LEVEL0,
              choices=tuple(GPU_EFFICIENT_LEVEL),
              help="The efficient level for training using XLA, from 0~5."),
         Flag("auto_scaling_batch_size", dtype=Flag.TYPE.BOOLEAN, default=None,
              help="Whether to automatically scale up the batch size to match the real tokens "
                   "when `gpu_efficient_level` > 0")
     ])
     return this_args
Example No. 10
 def class_or_method_args():
     return [
         ModuleFlag(Metric.REGISTRY_NAME,
                    help="The evaluation metric for the generation results."),
         ModuleFlag(SequenceSearch.REGISTRY_NAME, help="The search layer for sequence generation."),
         Flag("output_file", dtype=Flag.TYPE.STRING, default=None,
              help="The path to a file for generated outputs. If MultipleDataset is provided, "
                   "it should be a dict like {dataset_name0: data_path0, ...}"),
         Flag("save_metric", dtype=Flag.TYPE.STRING, default=None,
              help="The path to a file that metrics will be saved to, in json format."),
     ]
Example No. 11
 def class_or_method_args():
     return [
         Flag("data_file", dtype=Flag.TYPE.STRING, help="The text file"),
         Flag("data_is_processed",
              dtype=Flag.TYPE.BOOLEAN,
              help="Whether the text data is already processed."),
         Flag("data_lang",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="The language of the text."),
     ]
Example No. 12
 def class_or_method_args():
     return [
         Flag("src_file",
              dtype=Flag.TYPE.STRING,
              help="The source text file"),
         Flag("trg_file",
              dtype=Flag.TYPE.STRING,
              help="The target text file"),
         Flag("data_is_processed",
              dtype=Flag.TYPE.BOOLEAN,
              help="Whether the text data is already processed."),
     ]
Example No. 13
 def class_or_method_args():
     return [
         Flag("data_path",
              dtype=Flag.TYPE.STRING,
              help="The path to TF records."),
         Flag(
             "shuffle_dataset",
             dtype=Flag.TYPE.BOOLEAN,
             help="Whether to shuffle the TF records files. "
             "Note that some parts may be lost under MultiWorkerMirroredStrategy if set True."
         ),
     ]
Example No. 14
 def class_or_method_args():
     return [
         Flag("beam_size",
              dtype=Flag.TYPE.INTEGER,
              default=4,
              help="The beam width of beam search inference."),
         Flag("length_penalty",
              dtype=Flag.TYPE.FLOAT,
              default=0.6,
              help="The length penalty of beam search inference."),
         Flag(
             "top_k",
             dtype=Flag.TYPE.INTEGER,
             default=1,
             help=
             "The number of top-scoring predictions to keep from beam search inference."
         ),
         Flag("maximum_decode_length",
              dtype=Flag.TYPE.INTEGER,
              default=None,
              help="The maximum decoding length of beam search inference."),
         Flag("minimum_decode_length",
              dtype=Flag.TYPE.INTEGER,
              default=0,
              help="The minimum decoding length of beam search inference."),
         Flag(
             "extra_decode_length",
             dtype=Flag.TYPE.INTEGER,
             default=50,
             help=
             "The extra decoding length versus the source side for beam search inference. "
             "The maximum decoding length of beam search inference will be "
             "source_sequence_length + extra_decode_length if maximum_decode_length "
             "is not provided."),
         Flag(
             "padded_decode",
             dtype=Flag.TYPE.BOOLEAN,
             default=None,
             help=
             "Whether the autoregressive decoding runs with input data padded to "
             "the decode_max_length. For TPU/XLA-GPU runs, this flag has to be "
             "set due the static shape requirement. In addition, this method "
             "will introduce unnecessary overheads which grow quadratically with "
             "the max sequence length."),
         Flag(
             "ensemble_weights",
             dtype=Flag.TYPE.STRING,
             default="average",
             help=
             "The weight scheme for model ensemble, which could be comma-separated numbers."
         ),
         Flag(
             "enable_unk",
             dtype=Flag.TYPE.BOOLEAN,
             default=None,
             help="Whether to enable the search method to generating UNK."),
     ]
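Taken together, the help texts for maximum_decode_length and extra_decode_length imply a simple rule for the effective maximum decoding length. A hedged sketch of that rule (an assumption drawn from the help texts, not NeurST's verbatim code):
 # Decode-length rule implied by the help texts above (assumption).
 def effective_max_decode_length(source_sequence_length,
                                 maximum_decode_length=None,
                                 extra_decode_length=50):
     if maximum_decode_length is not None:
         return maximum_decode_length
     return source_sequence_length + extra_decode_length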
Example No. 15
 def class_or_method_args():
     this_args = super(PruneTuneTrainer,
                       PruneTuneTrainer).class_or_method_args()
     this_args.extend([
         Flag("partial_tuning",
              dtype=Flag.TYPE.BOOLEAN,
              default=False,
              help="Train partial weights according to mask"),
         Flag("mask_pkl",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="The file to the masks")
     ])
     return this_args
Example No. 16
 def class_or_method_args():
     this_args = super(PolynomialDecay,
                       PolynomialDecay).class_or_method_args()
     this_args += [
         Flag("initial_sparsity",
              dtype=Flag.TYPE.FLOAT,
              default=0.,
              help="Sparsity (%) at which weight_pruning begins"),
         Flag("polynomial_power",
              dtype=Flag.TYPE.FLOAT,
              default=3,
              help="Exponent to be used in the sparsity function")
     ]
     return this_args
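The flag names suggest the conventional polynomial pruning schedule (as popularized by TensorFlow's model-optimization toolkit), which interpolates from initial_sparsity down to a final sparsity. A sketch under that assumption; final_sparsity, begin_step, and end_step are hypothetical companion parameters defined elsewhere:
 # Conventional polynomial sparsity schedule (an assumption based on the
 # flag names, not necessarily this class's exact implementation).
 def sparsity_at(step, begin_step, end_step,
                 initial_sparsity, final_sparsity, polynomial_power=3.0):
     progress = (step - begin_step) / float(end_step - begin_step)
     progress = min(max(progress, 0.0), 1.0)  # clamp to [0, 1]
     return final_sparsity + (initial_sparsity - final_sparsity) * (
         (1.0 - progress) ** polynomial_power)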
Example No. 17
 def class_or_method_args():
     return [
         Flag("peak_lr",
              dtype=Flag.TYPE.FLOAT,
              default=5e-4,
              help="The configured lr."),
         Flag("init_lr",
              dtype=Flag.TYPE.FLOAT,
              default=0.,
              help="The initial lr."),
         Flag("warmup_steps",
              dtype=Flag.TYPE.INTEGER,
              default=4000,
              help="The number of steps required for linear warmup."),
     ]
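These three flags describe a linear warmup from init_lr to peak_lr over warmup_steps. A minimal sketch of the implied schedule (any post-warmup decay is defined elsewhere and omitted here):
 # Linear warmup implied by init_lr/peak_lr/warmup_steps (assumption;
 # post-warmup behavior omitted).
 def warmup_lr(step, init_lr=0., peak_lr=5e-4, warmup_steps=4000):
     if step >= warmup_steps:
         return peak_lr
     return init_lr + (peak_lr - init_lr) * step / float(warmup_steps)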
Example No. 18
 def class_or_method_args():
     this_args = super(MuSTC, MuSTC).class_or_method_args()
     this_args.append(
         Flag("extraction", dtype=Flag.TYPE.STRING, default=None,
              choices=MuSTC.EXTRACTION_CHOICES,
              help="The dataset portion to be extracted, e.g. train, dev, test (tst-COMMON)."))
     return this_args
Example No. 19
 def class_or_method_args():
     return [
         Flag("sample_sizes",
              dtype=Flag.TYPE.STRING,
              help="A dict. The key is the item name to be sampled, "
              "while the value is the corresponding proportion.")
     ]
Example No. 20
 def class_or_method_args():
     return [
         Flag("input_tarball",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="The original tarball."),
         Flag("excluded_file",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="A file containing transcriptions or translations "
              "that would be removed when reading the corpus "
              "(for filtering out testsets)."),
         ModuleFlag(FeatureExtractor.REGISTRY_NAME,
                    default=None,
                    help="The audio feature extractor.")
     ]
Example No. 21
 def class_or_method_args():
     this_args = TFRecordDataset.class_or_method_args()
     this_args.extend([
         Flag("feature_key",
              dtype=Flag.TYPE.STRING,
              default="audio",
              help="The key of the audio features in the TF Record."),
         Flag(
             "transcript_key",
             dtype=Flag.TYPE.STRING,
             default="transcript",
             help=
             "The key of the audio transcript/translation in the TF Record."
         ),
     ])
     return this_args
Example No. 22
 def class_or_method_args():
     this_args = super(LibriSpeech, LibriSpeech).class_or_method_args()
     this_args.append(
         Flag("excluded_file", dtype=Flag.TYPE.STRING, default=None,
              help="A file containing transcriptions "
                   "that would be removed in the LibriSpeech corpus."))
     return this_args
Example No. 23
 def class_or_method_args():
     """ Returns a list of args for flag definition. """
     flags = super(LabelSmoothedCrossEntropyWithKd,
                   LabelSmoothedCrossEntropyWithKd).class_or_method_args()
     flags.append(Flag("kd_weight", dtype=Flag.TYPE.FLOAT,
                       default=0.1, help="The weight for KD loss."))
     return flags
Example No. 24
 def class_or_method_args():
     this_args = super(SeqGenerationValidator,
                       SeqGenerationValidator).class_or_method_args()
     this_args.extend([
         ModuleFlag("eval_metric",
                    Metric.REGISTRY_NAME,
                    help="The metric for evaluating generation results."),
         ModuleFlag("eval_search_method",
                    SequenceSearch.REGISTRY_NAME,
                    help="The search layer for sequence generation."),
         Flag(
             "eval_estop_patience",
             dtype=Flag.TYPE.INTEGER,
             default=0,
             help=
             "If greater than 0, the training process will automatically stop "
             "once no better metric score has been obtained within the last "
             "`eval_estop_patience` evaluations."
         ),
         Flag(
             "eval_best_checkpoint_path",
             dtype=Flag.TYPE.STRING,
             default=None,
             help=
             "The path for checkpoints with best metric scores if provided,"
             "otherwise, default \"`model_dir`_best\" will be used."),
         Flag(
             "eval_auto_average_checkpoints",
             dtype=Flag.TYPE.BOOLEAN,
             default=True,
             help=
             "Whether to do checkpoint average on all model weights. An extra directory for averaged "
             "weights will be created. It is only available when `eval_best_checkpoint_path` is provided."
         ),
         Flag("eval_best_avg_checkpoint_path",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="The path to saving the averaged checkpoints."),
         Flag(
             "eval_top_checkpoints_to_keep",
             dtype=Flag.TYPE.INTEGER,
             default=10,
             help=
             "The maximum number of checkpoints to be saved (`max_to_keep` for checkpoint manager), "
             "and the number of latest checkpoints to be averaged if `eval_auto_average_checkpoints` is True. "
             "If <= 0, no more checkpoints will be saved."),
     ])
     return this_args
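A hedged sketch of the patience-based early stopping that eval_estop_patience describes (a hypothetical helper, not the validator's actual code; scores holds the metric score of each evaluation, higher being better):
 # Patience-based early stopping (assumption based on the
 # eval_estop_patience help text).
 def should_stop(scores, patience):
     if patience <= 0 or len(scores) <= patience:
         return False
     best_before = max(scores[:-patience])
     # Stop if none of the last `patience` evaluations beat the earlier best.
     return max(scores[-patience:]) <= best_before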
Example No. 25
 def class_or_method_args():
     """ Returns a list of args for flag definition. """
     return [
         Flag("label_smoothing",
              dtype=Flag.TYPE.FLOAT,
              default=0.,
              help="The label smoothing constant.")
     ]
Example No. 26
 def class_or_method_args():
     this_flags = super(MaskSequenceGenerator,
                        MaskSequenceGenerator).class_or_method_args()
     this_flags.append(
         Flag("mask_pkl",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="The path to the mask pkl file."), )
     return this_flags
Example No. 27
 def class_or_method_args():
     this_flags = super(TemperatureSampler,
                        TemperatureSampler).class_or_method_args()
     this_flags.append(
         Flag("temperature",
              dtype=Flag.TYPE.FLOAT,
              default=5,
              help="The temperature for sampling."))
     return this_flags
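The temperature here matches the usual temperature-based sampling for unbalanced datasets: dataset i is drawn with probability proportional to its size raised to 1/T, which flattens the distribution as T grows. A sketch under that assumption, where sizes is a hypothetical dict mapping dataset names to example counts:
 # Temperature-based sampling probabilities (assumption based on the
 # class and flag names).
 def sampling_probs(sizes, temperature=5.0):
     scaled = {k: v ** (1.0 / temperature) for k, v in sizes.items()}
     total = sum(scaled.values())
     return {k: v / total for k, v in scaled.items()}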
Example No. 28
 def class_or_method_args():
     return [
         Flag("input_tarball",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="The original tarball."),
         ModuleFlag(FeatureExtractor.REGISTRY_NAME,
                    default=None,
                    help="The audio feature extractor.")
     ]
Example No. 29
 def class_or_method_args():
     this_args = super(WaitkTranslation,
                       WaitkTranslation).class_or_method_args()
     this_args.extend([
         Flag("wait_k",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="The lagging k.")
     ])
     return this_args
Example No. 30
 def class_or_method_args():
     return [
         Flag(
             "data_files",
             dtype=Flag.TYPE.STRING,
             help="A dict of data files. The key is the dataset name while "
             "the value is a dict containing arguments indicating data files."
         ),
         Flag("data_class",
              dtype=Flag.TYPE.STRING,
              help="The dataset class for the data files."),
         Flag("common_properties",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="Other common properties for building a dataset."),
         ModuleFlag(DataSampler.REGISTRY_NAME,
                    default=None,
                    help="The sampler for unbalanced datasets.")
     ]