def class_or_method_args():
    this_args = super(SpeechToText, SpeechToText).class_or_method_args()
    this_args.extend([
        ModuleFlag("transcript_data_pipeline", DataPipeline.REGISTRY_NAME,
                   default=TranscriptDataPipeline.__name__,
                   help="The target side transcript data pipeline."),
        Flag("audio_feature_dim", dtype=Flag.TYPE.INTEGER, default=80,
             help="The dimension of audio features."),
        Flag("audio_feature_channels", dtype=Flag.TYPE.INTEGER, default=1,
             help="The number of channels of audio features."),
        Flag("max_src_len", dtype=Flag.TYPE.INTEGER, default=None,
             help="The maximum source length of training data (audio frames)."),
        Flag("min_src_bucket_boundary", dtype=Flag.TYPE.INTEGER, default=128,
             help="The minimum source length of the training bucket (audio frames)."),
        Flag("max_trg_len", dtype=Flag.TYPE.INTEGER, default=None,
             help="The maximum target length of training data."),
        Flag("truncate_src", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to truncate the source to max_src_len."),
        Flag("truncate_trg", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to truncate the target to max_trg_len."),
        Flag("experimental_frame_transcript_ratio", dtype=Flag.TYPE.INTEGER, default=None,
             help="The ratio between the number of audio frames and the transcript length, "
                  "used for bucketing training batches."),
        Flag("specaug", dtype=Flag.TYPE.STRING, default=None,
             help="The arguments for SpecAugment: either a predefined setting "
                  "such as LB, LD, SM, SS..., or a dict containing detailed arguments."),
        Flag("disable_batch_efficiency", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to disable the batch efficiency optimization.")
    ])
    return this_args
def class_or_method_args(): return [ Flag("eval_steps", dtype=Flag.TYPE.INTEGER, default=1000, help="The steps between two validation steps."), Flag("eval_start_at", dtype=Flag.TYPE.INTEGER, default=0, help="The step to start validation process."), ]
def class_or_method_args():
    this_args = super(LanguageModel, LanguageModel).class_or_method_args()
    this_args.extend([
        # for creating data pipelines
        ModuleFlag(DataPipeline.REGISTRY_NAME, help="The data pipeline."),
        # for preprocessing data
        Flag("max_len", dtype=Flag.TYPE.INTEGER, default=None,
             help="The maximum length of training data."),
        Flag("truncate", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to truncate data to max_len."),
        # for batching dataset
        Flag("batch_by_tokens", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to batch the data by word tokens."),
        Flag("begin_of_sentence", dtype=Flag.TYPE.STRING, default="bos",
             choices=["bos", "eos"],
             help="The begin-of-sentence symbol for the target side. The choice 'eos' "
                  "is for compatibility with the fairseq transformer."),
        Flag("gpu_efficient_level", dtype=Flag.TYPE.INTEGER,
             default=GPU_EFFICIENT_LEVEL.LEVEL0,
             choices=tuple(GPU_EFFICIENT_LEVEL),
             help="The efficiency level for training using XLA, from 0 to 5."),
    ])
    return this_args
def class_or_method_args():
    this_args = super(CommonVoice, CommonVoice).class_or_method_args()
    this_args.extend([
        Flag("extraction", dtype=Flag.TYPE.STRING, default=None,
             choices=CommonVoice.EXTRACTION_CHOICES,
             help="The dataset portion to be extracted, i.e. train, dev, test, other, validated."),
        Flag("language", dtype=Flag.TYPE.STRING, default=None,
             help="The language portion to be extracted, e.g. en, zh-CN. Must be provided "
                  "if the input is a directory.")
    ])
    return this_args
def class_or_method_args(): return [ Flag("multiple_datasets", dtype=Flag.TYPE.STRING, help="A dict of dataset class and parameters, " "where the key is the dataset name and " "the value is a dict of arguments for one dataset."), Flag("sample_weights", dtype=Flag.TYPE.FLOAT, help="A dict of weights for averaging metrics, where the key " "is the dataset name. 1.0 for each by default.") ]
def class_or_method_args():
    this_args = super(SequenceGeneratorSavedmodel,
                      SequenceGeneratorSavedmodel).class_or_method_args()
    this_args.extend([
        Flag("export_path", dtype=Flag.TYPE.STRING, default=None,
             help="The path to the SavedModel."),
        Flag("version", dtype=Flag.TYPE.INTEGER, default=1,
             help="The version of the model."),
    ])
    return this_args
def class_or_method_args(): return [ Flag("path", dtype=Flag.TYPE.STRING, help="The path to multilingual datasets. " "The record files should be stored in sub directories, which are named by src2trg, " "e.g. rootpath/en2de, rootpath/en2fr..."), Flag("auto_switch_langs", dtype=Flag.TYPE.BOOLEAN, help="Whether to switch source and target langauges (which will doubled the dataset)."), ModuleFlag(DataSampler.REGISTRY_NAME, default=TemperatureSampler.__name__, help="The sampler for unbalanced datasets.") ]
def class_or_method_args(): return [ Flag("schedule_steps", dtype=Flag.TYPE.STRING, default=None, help="A list of triggered steps."), Flag("schedule_lrs", dtype=Flag.TYPE.STRING, default=None, help="A list of learning rates."), ]
def class_or_method_args():
    this_args = super(Translation, Translation).class_or_method_args()
    this_args.extend([
        Flag("gpu_efficient_level", dtype=Flag.TYPE.INTEGER,
             default=GPU_EFFICIENT_LEVEL.LEVEL0,
             choices=tuple(GPU_EFFICIENT_LEVEL),
             help="The efficiency level for training using XLA, from 0 to 5."),
        Flag("auto_scaling_batch_size", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to automatically scale up the batch size to match the real "
                  "number of tokens when `gpu_efficient_level` > 0.")
    ])
    return this_args
def class_or_method_args():
    return [
        ModuleFlag(Metric.REGISTRY_NAME,
                   help="The evaluation metric for the generation results."),
        ModuleFlag(SequenceSearch.REGISTRY_NAME,
                   help="The search layer for sequence generation."),
        Flag("output_file", dtype=Flag.TYPE.STRING, default=None,
             help="The path to a file for generated outputs. If MultipleDataset is provided, "
                  "it should be a dict like {dataset_name0: data_path0, ...}."),
        Flag("save_metric", dtype=Flag.TYPE.STRING, default=None,
             help="The path to a file where metrics will be saved, in JSON format."),
    ]
def class_or_method_args(): return [ Flag("data_file", dtype=Flag.TYPE.STRING, help="The text file"), Flag("data_is_processed", dtype=Flag.TYPE.BOOLEAN, help="Whether the text data is already processed."), Flag("data_lang", dtype=Flag.TYPE.STRING, default=None, help="The language of the text."), ]
def class_or_method_args(): return [ Flag("src_file", dtype=Flag.TYPE.STRING, help="The source text file"), Flag("trg_file", dtype=Flag.TYPE.STRING, help="The target text file"), Flag("data_is_processed", dtype=Flag.TYPE.BOOLEAN, help="Whether the text data is already processed."), ]
def class_or_method_args(): return [ Flag("data_path", dtype=Flag.TYPE.STRING, help="The path to TF records."), Flag( "shuffle_dataset", dtype=Flag.TYPE.BOOLEAN, help="Whether to shuffle the TF records files. " "Note that some parts may be lost under MultiWorkerMirroredStrategy if set True." ), ]
def class_or_method_args(): return [ Flag("beam_size", dtype=Flag.TYPE.INTEGER, default=4, help="The beam width of beam search inference."), Flag("length_penalty", dtype=Flag.TYPE.FLOAT, default=0.6, help="The length penalty of beam search inference."), Flag( "top_k", dtype=Flag.TYPE.INTEGER, default=1, help= "The number of reserved predictions with top scores of beam search inference." ), Flag("maximum_decode_length", dtype=Flag.TYPE.INTEGER, default=None, help="The maximum decoding length of beam search inference."), Flag("minimum_decode_length", dtype=Flag.TYPE.INTEGER, default=0, help="The minimum decoding length of beam search inference."), Flag( "extra_decode_length", dtype=Flag.TYPE.INTEGER, default=50, help= "The extra decoding length versus source side for beam search inference. " "The maximum decoding length of beam search inference will be " "source_sequence_length + extra_decode_length if maximum_decode_length " "if not provided."), Flag( "padded_decode", dtype=Flag.TYPE.BOOLEAN, default=None, help= "Whether the autoregressive decoding runs with input data padded to " "the decode_max_length. For TPU/XLA-GPU runs, this flag has to be " "set due the static shape requirement. In addition, this method " "will introduce unnecessary overheads which grow quadratically with " "the max sequence length."), Flag( "ensemble_weights", dtype=Flag.TYPE.STRING, default="average", help= "The weight scheme for model ensemble, which could be comma-separated numbers." ), Flag( "enable_unk", dtype=Flag.TYPE.BOOLEAN, default=None, help="Whether to enable the search method to generating UNK."), ]
def class_or_method_args():
    this_args = super(PruneTuneTrainer, PruneTuneTrainer).class_or_method_args()
    this_args.extend([
        Flag("partial_tuning", dtype=Flag.TYPE.BOOLEAN, default=False,
             help="Whether to train only a subset of the weights, according to the mask."),
        Flag("mask_pkl", dtype=Flag.TYPE.STRING, default=None,
             help="The path to the mask file.")
    ])
    return this_args
def class_or_method_args():
    this_args = super(PolynomialDecay, PolynomialDecay).class_or_method_args()
    this_args += [
        Flag("initial_sparsity", dtype=Flag.TYPE.FLOAT, default=0.,
             help="The sparsity (%) at which weight pruning begins."),
        Flag("polynomial_power", dtype=Flag.TYPE.FLOAT, default=3,
             help="The exponent to be used in the sparsity function.")
    ]
    return this_args
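# Illustrative sketch (not part of the original code): the polynomial sparsity
# schedule commonly used for gradual pruning (Zhu & Gupta, 2017), which the
# `initial_sparsity` and `polynomial_power` flags above appear to parameterize.
# The names `final_sparsity`, `begin_step`, and `end_step` are assumed companions.
def polynomial_sparsity(step, initial_sparsity, final_sparsity,
                        begin_step, end_step, polynomial_power=3):
    """Interpolates sparsity from initial_sparsity to final_sparsity."""
    if step <= begin_step:
        return initial_sparsity
    if step >= end_step:
        return final_sparsity
    progress = (step - begin_step) / float(end_step - begin_step)
    return final_sparsity + (initial_sparsity - final_sparsity) * (
        (1.0 - progress) ** polynomial_power)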
def class_or_method_args(): return [ Flag("peak_lr", dtype=Flag.TYPE.FLOAT, default=5e-4, help="The configured lr."), Flag("init_lr", dtype=Flag.TYPE.FLOAT, default=0., help="The initial lr."), Flag("warmup_steps", dtype=Flag.TYPE.INTEGER, default=4000, help="The number of steps required for linear warmup."), ]
def class_or_method_args():
    this_args = super(MuSTC, MuSTC).class_or_method_args()
    this_args.append(
        Flag("extraction", dtype=Flag.TYPE.STRING, default=None,
             choices=MuSTC.EXTRACTION_CHOICES,
             help="The dataset portion to be extracted, e.g. train, dev, test (tst-COMMON)."))
    return this_args
def class_or_method_args(): return [ Flag("sample_sizes", dtype=Flag.TYPE.STRING, help="A dict. The key is the item name to be sampled, " "while the value is the corresponding proportion.") ]
def class_or_method_args(): return [ Flag("input_tarball", dtype=Flag.TYPE.STRING, default=None, help="The original tarball."), Flag("excluded_file", dtype=Flag.TYPE.STRING, default=None, help="A file containing transcriptions or translations " "that would be removed when reading the corpus " "(for filtering out testsets)."), ModuleFlag(FeatureExtractor.REGISTRY_NAME, default=None, help="The audio feature extractor.") ]
def class_or_method_args():
    this_args = TFRecordDataset.class_or_method_args()
    this_args.extend([
        Flag("feature_key", dtype=Flag.TYPE.STRING, default="audio",
             help="The key of the audio features in the TF record."),
        Flag("transcript_key", dtype=Flag.TYPE.STRING, default="transcript",
             help="The key of the audio transcript/translation in the TF record."),
    ])
    return this_args
def class_or_method_args():
    this_args = super(LibriSpeech, LibriSpeech).class_or_method_args()
    this_args.append(
        Flag("excluded_file", dtype=Flag.TYPE.STRING, default=None,
             help="A file containing transcriptions "
                  "to be removed from the LibriSpeech corpus."))
    return this_args
def class_or_method_args(): """ Returns a list of args for flag definition. """ flags = super(LabelSmoothedCrossEntropyWithKd, LabelSmoothedCrossEntropyWithKd).class_or_method_args() flags.append(Flag("kd_weight", dtype=Flag.TYPE.FLOAT, default=0.1, help="The weight for KD loss.")) return flags
def class_or_method_args():
    this_args = super(SeqGenerationValidator,
                      SeqGenerationValidator).class_or_method_args()
    this_args.extend([
        ModuleFlag("eval_metric", Metric.REGISTRY_NAME,
                   help="The metric for evaluating generation results."),
        ModuleFlag("eval_search_method", SequenceSearch.REGISTRY_NAME,
                   help="The search layer for sequence generation."),
        Flag("eval_estop_patience", dtype=Flag.TYPE.INTEGER, default=0,
             help="If greater than 0, the training process will automatically shut down "
                  "once no better metric score has been obtained for this many evaluations."),
        Flag("eval_best_checkpoint_path", dtype=Flag.TYPE.STRING, default=None,
             help="The path for saving checkpoints with the best metric scores. If not "
                  "provided, the default \"`model_dir`_best\" will be used."),
        Flag("eval_auto_average_checkpoints", dtype=Flag.TYPE.BOOLEAN, default=True,
             help="Whether to do checkpoint averaging on all model weights. An extra "
                  "directory for the averaged weights will be created. It is only "
                  "available when `eval_best_checkpoint_path` is provided."),
        Flag("eval_best_avg_checkpoint_path", dtype=Flag.TYPE.STRING, default=None,
             help="The path for saving the averaged checkpoints."),
        Flag("eval_top_checkpoints_to_keep", dtype=Flag.TYPE.INTEGER, default=10,
             help="The maximum number of checkpoints to be saved (`max_to_keep` for the "
                  "checkpoint manager), and the number of latest checkpoints to be averaged "
                  "if `eval_auto_average_checkpoints` is True. If <= 0, no more checkpoints "
                  "will be saved."),
    ])
    return this_args
def class_or_method_args(): """ Returns a list of args for flag definition. """ return [ Flag("label_smoothing", dtype=Flag.TYPE.FLOAT, default=0., help="The label smoothing constant.") ]
def class_or_method_args():
    this_flags = super(MaskSequenceGenerator, MaskSequenceGenerator).class_or_method_args()
    this_flags.append(
        Flag("mask_pkl", dtype=Flag.TYPE.STRING, default=None,
             help="The path to the mask pkl file."))
    return this_flags
def class_or_method_args():
    this_flags = super(TemperatureSampler, TemperatureSampler).class_or_method_args()
    this_flags.append(
        Flag("temperature", dtype=Flag.TYPE.FLOAT, default=5,
             help="The temperature for sampling."))
    return this_flags
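# Illustrative sketch (not part of the original code): temperature-based
# sampling as commonly used for unbalanced multilingual datasets, where each
# dataset of size n_i is sampled with probability proportional to
# (n_i / sum_j n_j) ** (1 / temperature). A higher temperature flattens the
# distribution toward uniform. The helper name is an assumption.
def temperature_sampling_probs(dataset_sizes, temperature=5.0):
    total = float(sum(dataset_sizes.values()))
    scaled = {name: (size / total) ** (1.0 / temperature)
              for name, size in dataset_sizes.items()}
    norm = sum(scaled.values())
    return {name: p / norm for name, p in scaled.items()}

# e.g. temperature_sampling_probs({"en2de": 900, "en2fr": 100})
# -> roughly {"en2de": 0.61, "en2fr": 0.39}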
def class_or_method_args(): return [ Flag("input_tarball", dtype=Flag.TYPE.STRING, default=None, help="The original tarball."), ModuleFlag(FeatureExtractor.REGISTRY_NAME, default=None, help="The audio feature extractor.") ]
def class_or_method_args():
    this_args = super(WaitkTranslation, WaitkTranslation).class_or_method_args()
    this_args.extend([
        Flag("wait_k", dtype=Flag.TYPE.STRING, default=None,
             help="The lagging k.")
    ])
    return this_args
def class_or_method_args(): return [ Flag( "data_files", dtype=Flag.TYPE.STRING, help="A dict of data files. The key is the dataset name while " "the value is a dict containing arguments indicating data files." ), Flag("data_class", dtype=Flag.TYPE.STRING, help="The dataset class for the data files."), Flag("common_properties", dtype=Flag.TYPE.STRING, default=None, help="Other common properties for building a dataset."), ModuleFlag(DataSampler.REGISTRY_NAME, default=None, help="The sampler for unbalanced datasets.") ]