def class_or_method_args():
    this_args = super(Seq2Seq, Seq2Seq).class_or_method_args()
    this_args.extend([
        # for creating data pipelines
        ModuleFlag("src_data_pipeline", DataPipeline.REGISTRY_NAME,
                   help="The source side data pipeline."),
        ModuleFlag("trg_data_pipeline", DataPipeline.REGISTRY_NAME,
                   help="The target side data pipeline."),
        # for preprocessing data
        Flag("max_src_len", dtype=Flag.TYPE.INTEGER, default=None,
             help="The maximum source length of training data."),
        Flag("max_trg_len", dtype=Flag.TYPE.INTEGER, default=None,
             help="The maximum target length of training data."),
        Flag("truncate_src", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to truncate source to max_src_len."),
        Flag("truncate_trg", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to truncate target to max_trg_len."),
        # for batching dataset
        Flag("batch_by_tokens", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to batch the data by word tokens."),
        Flag("target_begin_of_sentence", dtype=Flag.TYPE.STRING, default="bos",
             choices=["bos", "eos"],
             help="The beginning-of-sentence symbol for the target side. The choice 'eos' "
                  "is for compatibility with the fairseq transformer.")
    ])
    return this_args
def class_or_method_args():
    return [
        ModuleFlag(Metric.REGISTRY_NAME,
                   help="The evaluation metric for the generation results."),
        ModuleFlag(SequenceSearch.REGISTRY_NAME,
                   help="The search layer for sequence generation."),
        Flag("output_file", dtype=Flag.TYPE.STRING, default=None,
             help="The path to a file for generated outputs. If MultipleDataset is provided, "
                  "it should be a dict like {dataset_name0: data_path0, ...}"),
        Flag("save_metric", dtype=Flag.TYPE.STRING, default=None,
             help="The path to a file that metrics will be saved to, in json format."),
    ]
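# Illustrative sketch (not from the source): per the `output_file` help above, when a
# MultipleDataset is used the value should be a dict mapping dataset names to output
# paths. The dataset names and paths below are hypothetical.
_example_output_file = {
    "dataset_name0": "/path/to/output0.hyp.txt",
    "dataset_name1": "/path/to/output1.hyp.txt",
}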
def class_or_method_args(): """ Returns a list of args for flag definition. """ this_args = super(SpeechToText, SpeechToText).class_or_method_args() this_args.extend([ ModuleFlag("transcript_data_pipeline", DataPipeline.REGISTRY_NAME, default=TranscriptDataPipeline.__name__, help="The data pipeline for ASR transcription."), ModuleFlag("translation_data_pipeline", DataPipeline.REGISTRY_NAME, default=TranscriptDataPipeline.__name__, help="The data pipeline for translation."), ]) return this_args
def class_or_method_args():
    this_args = super(SpeechToText, SpeechToText).class_or_method_args()
    this_args.extend([
        ModuleFlag("transcript_data_pipeline", DataPipeline.REGISTRY_NAME,
                   default=TranscriptDataPipeline.__name__,
                   help="The target side transcript data pipeline."),
        Flag("audio_feature_dim", dtype=Flag.TYPE.INTEGER, default=80,
             help="The dimension of audio features."),
        Flag("audio_feature_channels", dtype=Flag.TYPE.INTEGER, default=1,
             help="The number of channels of audio features."),
        Flag("max_src_len", dtype=Flag.TYPE.INTEGER, default=None,
             help="The maximum source length of training data (audio frames)."),
        Flag("min_src_bucket_boundary", dtype=Flag.TYPE.INTEGER, default=128,
             help="The minimum source length of the training bucket (audio frames)."),
        Flag("max_trg_len", dtype=Flag.TYPE.INTEGER, default=None,
             help="The maximum target length of training data."),
        Flag("truncate_src", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to truncate source to max_src_len."),
        Flag("truncate_trg", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to truncate target to max_trg_len."),
        Flag("experimental_frame_transcript_ratio", dtype=Flag.TYPE.INTEGER, default=None,
             help="The ratio between the number of audio frames and the transcript length, "
                  "used for bucketing training batches."),
        Flag("specaug", dtype=Flag.TYPE.STRING, default=None,
             help="The arguments for SpecAugment, either predefined settings "
                  "like LB, LD, SM, SS... or a dict containing detailed arguments."),
        Flag("disable_batch_efficiency", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to disable the batch efficiency.")
    ])
    return this_args
def class_or_method_args():
    this_args = super(LanguageModel, LanguageModel).class_or_method_args()
    this_args.extend([
        # for creating data pipelines
        ModuleFlag(DataPipeline.REGISTRY_NAME, help="The data pipeline."),
        # for preprocessing data
        Flag("max_len", dtype=Flag.TYPE.INTEGER, default=None,
             help="The maximum length of training data."),
        Flag("truncate", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to truncate data to max_len."),
        # for batching dataset
        Flag("batch_by_tokens", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to batch the data by word tokens."),
        Flag("begin_of_sentence", dtype=Flag.TYPE.STRING, default="bos",
             choices=["bos", "eos"],
             help="The beginning-of-sentence symbol for the target side. The choice 'eos' "
                  "is for compatibility with the fairseq transformer."),
        Flag("gpu_efficient_level", dtype=Flag.TYPE.INTEGER,
             default=GPU_EFFICIENT_LEVEL.LEVEL0,
             choices=tuple(GPU_EFFICIENT_LEVEL),
             help="The GPU efficiency level for training with XLA, from 0 to 5."),
    ])
    return this_args
def class_or_method_args():
    this_args = super(SeqGenerationValidator, SeqGenerationValidator).class_or_method_args()
    this_args.extend([
        ModuleFlag("eval_metric", Metric.REGISTRY_NAME,
                   help="The metric for evaluating generation results."),
        ModuleFlag("eval_search_method", SequenceSearch.REGISTRY_NAME,
                   help="The search layer for sequence generation."),
        Flag("eval_estop_patience", dtype=Flag.TYPE.INTEGER, default=0,
             help="If greater than 0, the training process will automatically stop "
                  "once it fails to acquire a better metric score within this many evaluations."),
        Flag("eval_best_checkpoint_path", dtype=Flag.TYPE.STRING, default=None,
             help="The path for saving checkpoints with the best metric scores. "
                  "If not provided, \"`model_dir`_best\" will be used by default."),
        Flag("eval_auto_average_checkpoints", dtype=Flag.TYPE.BOOLEAN, default=True,
             help="Whether to do checkpoint averaging on all model weights. An extra directory "
                  "for averaged weights will be created. It is only available when "
                  "`eval_best_checkpoint_path` is provided."),
        Flag("eval_best_avg_checkpoint_path", dtype=Flag.TYPE.STRING, default=None,
             help="The path for saving the averaged checkpoints."),
        Flag("eval_top_checkpoints_to_keep", dtype=Flag.TYPE.INTEGER, default=10,
             help="The maximum number of checkpoints to be saved (`max_to_keep` for the checkpoint "
                  "manager), and the number of latest checkpoints to be averaged if "
                  "`eval_auto_average_checkpoints` is True. If <= 0, no more checkpoints will be saved."),
    ])
    return this_args
def class_or_method_args(): return [ ModuleFlag(Criterion.REGISTRY_NAME, help="The criterion for training or evaluation."), ModuleFlag(OPTIMIZER_REGISTRY_NAME, help="The optimizer for training."), ModuleFlag(LR_SCHEDULE_REGISTRY_NAME, help="The learning schedule for training."), ModuleFlag(Validator.REGISTRY_NAME, help="The validation process while training."), ModuleFlag(PruningSchedule.REGISTRY_NAME, help="The schedule for weight weight_pruning.", default=PolynomialDecay.__name__), Flag("tb_log_dir", dtype=Flag.TYPE.STRING, default=None, help="The path to store tensorboard summary, or `model_dir`/train by default."), Flag("train_steps", dtype=Flag.TYPE.INTEGER, default=10000000, help="The maximum steps for training loop."), Flag("summary_steps", dtype=Flag.TYPE.INTEGER, default=200, help="Doing summary(logging & tensorboard) this every steps."), Flag("save_checkpoint_steps", dtype=Flag.TYPE.INTEGER, default=1000, help="Saving checkpoints this every steps."), Flag("checkpoints_max_to_keep", dtype=Flag.TYPE.INTEGER, default=8, help="The maximum checkpoints to be kept."), Flag("initial_global_step", dtype=Flag.TYPE.INTEGER, default=None, help="The manually specified initial global step."), Flag("pretrain_model", dtype=Flag.TYPE.STRING, default=None, multiple=True, help="The path to a pretrained model directory(a seq2seq model, bert model, etc.). " "(V2) Or a json/yaml-like dict string indicating pretrained models from " "either neurst checkpoints or publicly available models converted " "by neurst.utils.converters. Each entry has the elements: " "path, model_name, from_prefix, to_prefix, name_pattern. " "Multiple pretrain models are also available."), Flag("pretrain_variable_pattern", dtype=Flag.TYPE.STRING, default=None, multiple=True, help="One can restore specified variables in the `pretrain_model` by this regular expression." "Multiple pattern are also available, but must match to `pretrain_model`."), Flag("update_cycle", dtype=Flag.TYPE.INTEGER, default=1, help="Training step with this many batches (Gradient Accumulation)."), Flag("clip_value", dtype=Flag.TYPE.FLOAT, default=None, help="Gradient clipping by value."), Flag("clip_norm", dtype=Flag.TYPE.FLOAT, default=None, help="Gradient clipping by norm."), Flag("experimental_count_batch_num", dtype=Flag.TYPE.BOOLEAN, default=None, help="Pre-scan the dataset for training and count the number of batches."), Flag("freeze_variables", dtype=Flag.TYPE.STRING, default=None, help="Variables whose names are matched with this regex will be freezed."), Flag("pruning_variable_pattern", dtype=Flag.TYPE.STRING, default=None, help="The regular expression that indicates the variables will be pruned."), Flag("nopruning_variable_pattern", dtype=Flag.TYPE.STRING, default=None, help="The regular expression that indicates the variables will NOT be pruned " "(will take effect if `pruning_variable_pattern`=None)."), ]
def class_or_method_args(): return [ Flag("input_tarball", dtype=Flag.TYPE.STRING, default=None, help="The original tarball."), ModuleFlag(FeatureExtractor.REGISTRY_NAME, default=None, help="The audio feature extractor.") ]
def class_or_method_args(): return [ Flag("path", dtype=Flag.TYPE.STRING, help="The path to multilingual datasets. " "The record files should be stored in sub directories, which are named by src2trg, " "e.g. rootpath/en2de, rootpath/en2fr..."), Flag("auto_switch_langs", dtype=Flag.TYPE.BOOLEAN, help="Whether to switch source and target langauges (which will doubled the dataset)."), ModuleFlag(DataSampler.REGISTRY_NAME, default=TemperatureSampler.__name__, help="The sampler for unbalanced datasets.") ]
def class_or_method_args():
    this_args = super(CriterionValidator, CriterionValidator).class_or_method_args()
    this_args.extend([
        ModuleFlag("eval_criterion", Criterion.REGISTRY_NAME,
                   help="The criterion for validation."),
        ModuleFlag("eval_dataset", Dataset.REGISTRY_NAME,
                   help="The dataset for validation."),
        Flag("eval_batch_size", dtype=Flag.TYPE.INTEGER, default=32,
             help="The batch size for validation process."),
        Flag("eval_task_args", dtype=Flag.TYPE.STRING, default=None,
             help="Other parameters for building validation dataset.")
    ])
    return this_args
def class_or_method_args():
    this_args = super(Seq2Seq, Seq2Seq).class_or_method_args()
    this_args.extend([
        # for creating data pipelines
        ModuleFlag("src_data_pipeline", DataPipeline.REGISTRY_NAME,
                   help="The source side data pipeline."),
        ModuleFlag("trg_data_pipeline", DataPipeline.REGISTRY_NAME,
                   help="The target side data pipeline."),
        # for preprocessing data
        Flag("max_src_len", dtype=Flag.TYPE.INTEGER, default=None,
             help="The maximum source length of training data."),
        Flag("max_trg_len", dtype=Flag.TYPE.INTEGER, default=None,
             help="The maximum target length of training data."),
        Flag("truncate_src", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to truncate source to max_src_len."),
        Flag("truncate_trg", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to truncate target to max_trg_len."),
        # for batching dataset
        Flag("batch_by_tokens", dtype=Flag.TYPE.BOOLEAN, default=None,
             help="Whether to batch the data by word tokens."),
    ])
    return this_args
def class_or_method_args():
    return [
        ModuleFlag(Encoder.REGISTRY_NAME, default=None, help="The encoder."),
        ModuleFlag(Decoder.REGISTRY_NAME, default=None, help="The decoder."),
        Flag("modality.share_source_target_embedding", dtype=Flag.TYPE.BOOLEAN, default=False,
             help="Whether to share source and target embedding table."),
        Flag("modality.share_embedding_and_softmax_weights", dtype=Flag.TYPE.BOOLEAN, default=False,
             help="Whether to share the target embedding table and softmax weights."),
        Flag("modality.dim", dtype=Flag.TYPE.INTEGER, default=None,
             help="The default embedding dimension for both source and target side."),
        Flag("modality.source.dim", dtype=Flag.TYPE.INTEGER, default=None,
             help="The source-side embedding dimension, or `modality.dim` if not provided."),
        Flag("modality.target.dim", dtype=Flag.TYPE.INTEGER, default=None,
             help="The target-side embedding dimension, or `modality.dim` if not provided."),
        Flag("modality.timing", dtype=Flag.TYPE.STRING, default=None,
             help="The arbitrary parameters for positional encoding of both source and target side."),
        Flag("modality.source.timing", dtype=Flag.TYPE.STRING, default=None,
             help="The arbitrary parameters for source-side positional encoding, "
                  "or `modality.timing` if not provided."),
        Flag("modality.target.timing", dtype=Flag.TYPE.STRING, default=None,
             help="The arbitrary parameters for target-side positional encoding, "
                  "or `modality.timing` if not provided.")
    ]
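# Illustrative sketch (not from the source): how the hierarchical `modality.*` flags above
# might be filled in. Per the help strings, the per-side `modality.source.dim` and
# `modality.target.dim` fall back to `modality.dim` when left unset. All concrete values,
# including the "sinusoids" timing setting, are hypothetical.
_example_modality_args = {
    "modality.share_embedding_and_softmax_weights": True,
    "modality.dim": 512,             # used for both sides ...
    "modality.source.dim": None,     # ... because the per-side dims are left unset
    "modality.target.dim": None,
    "modality.timing": "sinusoids",  # hypothetical positional encoding parameters
}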
def class_or_method_args(): return [ Flag("input_tarball", dtype=Flag.TYPE.STRING, default=None, help="The original tarball."), Flag("excluded_file", dtype=Flag.TYPE.STRING, default=None, help="A file containing transcriptions or translations " "that would be removed when reading the corpus " "(for filtering out testsets)."), ModuleFlag(FeatureExtractor.REGISTRY_NAME, default=None, help="The audio feature extractor.") ]
def class_or_method_args(): return [ Flag( "data_files", dtype=Flag.TYPE.STRING, help="A dict of data files. The key is the dataset name while " "the value is a dict containing arguments indicating data files." ), Flag("data_class", dtype=Flag.TYPE.STRING, help="The dataset class for the data files."), Flag("common_properties", dtype=Flag.TYPE.STRING, default=None, help="Other common properties for building a dataset."), ModuleFlag(DataSampler.REGISTRY_NAME, default=None, help="The sampler for unbalanced datasets.") ]
def class_or_method_args(): return [ Flag( "tb_log_dir", dtype=Flag.TYPE.STRING, default=None, help= "The path to store tensorboard summary, or `model_dir`/validation by default." ), Flag("waiting_interval", dtype=Flag.TYPE.INTEGER, default=120, help="The waiting interval between two evaluation steps."), Flag("maximum_waiting_time", dtype=Flag.TYPE.INTEGER, default=3600, help="The maximum waiting time(in seconds)."), ModuleFlag(Validator.REGISTRY_NAME, help="The validation process during training."), ]
def class_or_method_args(): return [ Flag( "data_files", dtype=Flag.TYPE.STRING, help= "A dict of parallel data files. The key is the dataset name while " "the value is a dict containing `src_file` and `trg_file`."), Flag("data_is_processed", dtype=Flag.TYPE.BOOLEAN, help="Whether the text data is already processed."), Flag("src_lang", dtype=Flag.TYPE.STRING, default=None, help="The source language"), Flag("trg_lang", dtype=Flag.TYPE.STRING, default=None, help="The target language"), ModuleFlag(DataSampler.REGISTRY_NAME, default=None, help="The sampler for unbalanced datasets.") ]
def class_or_method_args(): return [ ModuleFlag(Criterion.REGISTRY_NAME, help="The criterion for evaluation."), ]
def class_or_method_args(): return [ ModuleFlag(Criterion.REGISTRY_NAME, help="The criterion for training or evaluation."), ModuleFlag(OPTIMIZER_REGISTRY_NAME, help="The optimizer for training."), ModuleFlag(LR_SCHEDULE_REGISTRY_NAME, help="The learning schedule for training."), ModuleFlag(Validator.REGISTRY_NAME, help="The validation process while training."), Flag( "tb_log_dir", dtype=Flag.TYPE.STRING, default=None, help= "The path to store tensorboard summary, or `model_dir`/train by default." ), Flag("train_steps", dtype=Flag.TYPE.INTEGER, default=10000000, help="The maximum steps for training loop."), Flag( "summary_steps", dtype=Flag.TYPE.INTEGER, default=200, help="Doing summary(logging & tensorboard) this every steps."), Flag("save_checkpoint_steps", dtype=Flag.TYPE.INTEGER, default=1000, help="Saving checkpoints this every steps."), Flag("checkpoints_max_to_keep", dtype=Flag.TYPE.INTEGER, default=8, help="The maximum checkpoints to be kept."), Flag("initial_global_step", dtype=Flag.TYPE.INTEGER, default=None, help="The manually specified initial global step."), Flag( "pretrain_model", dtype=Flag.TYPE.STRING, default=None, multiple=True, help= "The path to a pretrained model directory(a seq2seq model, bert model, etc.). " "Multiple pretrain models are also available."), Flag( "pretrain_variable_pattern", dtype=Flag.TYPE.STRING, default=None, multiple=True, help= "One can restore specified variables in the `pretrain_model` by this regular expression." "Multiple pattern are also available, but must match to `pretrain_model`." ), Flag( "update_cycle", dtype=Flag.TYPE.INTEGER, default=1, help= "Training step with this many batches (Gradient Accumulation)." ), Flag("clip_value", dtype=Flag.TYPE.FLOAT, default=None, help="Gradient clipping by value."), Flag("clip_norm", dtype=Flag.TYPE.FLOAT, default=None, help="Gradient clipping by norm."), Flag( "experimental_count_batch_num", dtype=Flag.TYPE.BOOLEAN, default=None, help= "Pre-scan the dataset for training and count the number of batches." ) ]