Example #1
 def class_or_method_args():
     this_args = super(Seq2Seq, Seq2Seq).class_or_method_args()
     this_args.extend([
         # for creating data pipelines
         ModuleFlag("src_data_pipeline", DataPipeline.REGISTRY_NAME,
                    help="The source side data pipeline."),
         ModuleFlag("trg_data_pipeline", DataPipeline.REGISTRY_NAME,
                    help="The target side data pipeline."),
         # for preprocessing data
         Flag("max_src_len", dtype=Flag.TYPE.INTEGER, default=None,
              help="The maximum source length of training data."),
         Flag("max_trg_len", dtype=Flag.TYPE.INTEGER, default=None,
              help="The maximum target length of training data."),
         Flag("truncate_src", dtype=Flag.TYPE.BOOLEAN, default=None,
              help="Whether to truncate source to max_src_len."),
         Flag("truncate_trg", dtype=Flag.TYPE.BOOLEAN, default=None,
              help="Whether to truncate target to max_trg_len."),
         # for batching dataset
         Flag("batch_by_tokens", dtype=Flag.TYPE.BOOLEAN, default=None,
              help="Whether to batch the data by word tokens."),
         Flag("target_begin_of_sentence", dtype=Flag.TYPE.STRING, default="bos",
              choices=["bos", "eos"],
              help="The begin of sentence symbol for target side. The choice 'eos' "
                   "is for compatibility with fairseq transformer.")
     ])
     return this_args
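
As a rough mental model only, here is a minimal, self-contained sketch of how a flag list like the one above could be turned into `argparse` options. The shim classes below are hypothetical stand-ins, not neurst's actual `Flag`/`ModuleFlag` implementation, and `ModuleFlag` (which pulls in a registered module's own flags) is omitted from the sketch.

 import argparse
 from enum import Enum

 class Flag:
     """Hypothetical stand-in for a command-line flag definition."""

     class TYPE(Enum):
         STRING = str
         INTEGER = int
         FLOAT = float
         BOOLEAN = bool

     def __init__(self, name, dtype=None, default=None, help=None,
                  choices=None, multiple=False):
         self.name, self.dtype, self.default = name, dtype, default
         self.help, self.choices, self.multiple = help, choices, multiple

 def build_parser(flags):
     """Convert a list of Flag objects into an argparse parser (sketch only)."""
     parser = argparse.ArgumentParser()
     for f in flags:
         kwargs = {"default": f.default, "help": f.help}
         if f.choices:
             kwargs["choices"] = f.choices
         if f.dtype is Flag.TYPE.BOOLEAN:
             kwargs["action"] = "store_true"  # booleans become on/off switches
         elif f.multiple:
             kwargs["action"] = "append"      # repeatable flags accumulate
             kwargs["type"] = f.dtype.value if f.dtype else str
         else:
             kwargs["type"] = f.dtype.value if f.dtype else str
         parser.add_argument("--" + f.name, **kwargs)
     return parser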
Example #2
 def class_or_method_args():
     return [
         ModuleFlag(Metric.REGISTRY_NAME,
                    help="The evaluation metric for the generation results."),
         ModuleFlag(SequenceSearch.REGISTRY_NAME, help="The search layer for sequence generation."),
         Flag("output_file", dtype=Flag.TYPE.STRING, default=None,
              help="The path to a file for generated outputs. If MultipleDataset is provided, "
                   "it should be a dict like {dataset_name0: data_path0, ...}"),
         Flag("save_metric", dtype=Flag.TYPE.STRING, default=None,
              help="The path to a file that metrics will be saved to, in json format."),
     ]
Example #3
 def class_or_method_args():
     """ Returns a list of args for flag definition. """
     this_args = super(SpeechToText, SpeechToText).class_or_method_args()
     this_args.extend([
         ModuleFlag("transcript_data_pipeline",
                    DataPipeline.REGISTRY_NAME,
                    default=TranscriptDataPipeline.__name__,
                    help="The data pipeline for ASR transcription."),
         ModuleFlag("translation_data_pipeline",
                    DataPipeline.REGISTRY_NAME,
                    default=TranscriptDataPipeline.__name__,
                    help="The data pipeline for translation."),
     ])
     return this_args
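
A side note on the recurring `super(Cls, Cls)` pattern: these methods appear to be static (they take no `self`), and passing the class twice is the standard idiom for reaching a parent class's staticmethod when no instance is available. A toy illustration with hypothetical class names:

 class Base:
     @staticmethod
     def class_or_method_args():
         return ["base_flag"]

 class Child(Base):
     @staticmethod
     def class_or_method_args():
         # super(Child, Child) starts attribute lookup after Child in the
         # MRO, so the parent's staticmethod is reachable with no instance.
         args = super(Child, Child).class_or_method_args()
         args.append("child_flag")
         return args

 assert Child.class_or_method_args() == ["base_flag", "child_flag"]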
Example #4
 def class_or_method_args():
     this_args = super(SpeechToText, SpeechToText).class_or_method_args()
     this_args.extend([
         ModuleFlag("transcript_data_pipeline",
                    DataPipeline.REGISTRY_NAME,
                    default=TranscriptDataPipeline.__name__,
                    help="The target side transcript data pipeline."),
         Flag("audio_feature_dim",
              dtype=Flag.TYPE.INTEGER,
              default=80,
              help="The dimension of audio features."),
         Flag("audio_feature_channels",
              dtype=Flag.TYPE.INTEGER,
              default=1,
              help="The number of channels of audio features."),
         Flag("max_src_len",
              dtype=Flag.TYPE.INTEGER,
              default=None,
              help=
              "The maximum source length of training data (audio frames)."),
         Flag(
             "min_src_bucket_boundary",
             dtype=Flag.TYPE.INTEGER,
             default=128,
             help=
             "The minimum source length of the training bucket (audio frames)."
         ),
         Flag("max_trg_len",
              dtype=Flag.TYPE.INTEGER,
              default=None,
              help="The maximum target length of training data."),
         Flag("truncate_src",
              dtype=Flag.TYPE.BOOLEAN,
              default=None,
              help="Whether to truncate source to max_src_len."),
         Flag("truncate_trg",
              dtype=Flag.TYPE.BOOLEAN,
              default=None,
              help="Whether to truncate target to max_trg_len."),
         Flag(
             "experimental_frame_transcript_ratio",
             dtype=Flag.TYPE.INTEGER,
             default=None,
             help=
             "The ratio of the number of frames and its transcript for training batch bucket."
         ),
         Flag(
             "specaug",
             dtype=Flag.TYPE.STRING,
             default=None,
             help=
             "The arguments for spec augment, can be either predefined settings "
             "like LB, LD, SM, SS... or a dict containing detailed arguments."
         ),
         Flag("disable_batch_efficiency",
              dtype=Flag.TYPE.BOOLEAN,
              default=None,
              help="Whether to disable the batch efficiency.")
     ])
     return this_args
Example #5
 def class_or_method_args():
     this_args = super(LanguageModel, LanguageModel).class_or_method_args()
     this_args.extend([
         # for creating data pipelines
         ModuleFlag(DataPipeline.REGISTRY_NAME, help="The data pipeline."),
         # for preprocessing data
         Flag("max_len",
              dtype=Flag.TYPE.INTEGER,
              default=None,
              help="The maximum length of training data."),
         Flag("truncate",
              dtype=Flag.TYPE.BOOLEAN,
              default=None,
              help="Whether to truncate data to max_len."),
         # for batching dataset
         Flag("batch_by_tokens",
              dtype=Flag.TYPE.BOOLEAN,
              default=None,
              help="Whether to batch the data by word tokens."),
         Flag(
             "begin_of_sentence",
             dtype=Flag.TYPE.STRING,
             default="bos",
             choices=["bos", "eos"],
             help=
             "The begin of sentence symbol for target side. The choice 'eos' "
             "is for compatibility with fairseq transformer."),
         Flag("gpu_efficient_level",
              dtype=Flag.TYPE.INTEGER,
              default=GPU_EFFICIENT_LEVEL.LEVEL0,
              choices=tuple(GPU_EFFICIENT_LEVEL),
              help="The efficient level for training using XLA, from 0~5."),
     ])
     return this_args
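
`choices=tuple(GPU_EFFICIENT_LEVEL)` implies that `GPU_EFFICIENT_LEVEL` is an iterable enumeration covering levels 0 through 5. A hypothetical sketch consistent with the flag's help text, not neurst's actual definition:

 from enum import IntEnum

 class GPU_EFFICIENT_LEVEL(IntEnum):
     LEVEL0 = 0  # the flag's default
     LEVEL1 = 1
     LEVEL2 = 2
     LEVEL3 = 3
     LEVEL4 = 4
     LEVEL5 = 5

 # Iterating the enum yields its members, so tuple(GPU_EFFICIENT_LEVEL)
 # can serve directly as the `choices` argument.
 assert tuple(GPU_EFFICIENT_LEVEL)[0] is GPU_EFFICIENT_LEVEL.LEVEL0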
Example #6
 def class_or_method_args():
     this_args = super(SeqGenerationValidator,
                       SeqGenerationValidator).class_or_method_args()
     this_args.extend([
         ModuleFlag("eval_metric",
                    Metric.REGISTRY_NAME,
                    help="The metric for evaluating generation results."),
         ModuleFlag("eval_search_method",
                    SequenceSearch.REGISTRY_NAME,
                    help="The search layer for sequence generation."),
         Flag(
             "eval_estop_patience",
             dtype=Flag.TYPE.INTEGER,
             default=0,
             help=
             "The training process will automatically shut down until the program "
             "fail to acquire a better metric score anymore if `early_stop_patience` greater than 0."
         ),
         Flag(
             "eval_best_checkpoint_path",
             dtype=Flag.TYPE.STRING,
             default=None,
             help=
             "The path for checkpoints with best metric scores if provided,"
             "otherwise, default \"`model_dir`_best\" will be used."),
         Flag(
             "eval_auto_average_checkpoints",
             dtype=Flag.TYPE.BOOLEAN,
             default=True,
             help=
             "Whether to do checkpoint average on all model weights. An extra directory for averaged "
             "weights will be created. It is only available when `eval_best_checkpoint_path` is provided."
         ),
         Flag("eval_best_avg_checkpoint_path",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="The path to saving the averaged checkpoints."),
         Flag(
             "eval_top_checkpoints_to_keep",
             dtype=Flag.TYPE.INTEGER,
             default=10,
             help=
             "The maximum number of checkpoints to be saved (`max_to_keep` for checkpoint manager), "
             "and the number of latest checkpoints to be averaged if `eval_auto_average_checkpoints` is True. "
             "If <= 0, no more checkpoints will be saved."),
     ])
     return this_args
Example #7
 def class_or_method_args():
     return [
         ModuleFlag(Criterion.REGISTRY_NAME, help="The criterion for training or evaluation."),
         ModuleFlag(OPTIMIZER_REGISTRY_NAME, help="The optimizer for training."),
         ModuleFlag(LR_SCHEDULE_REGISTRY_NAME, help="The learning schedule for training."),
         ModuleFlag(Validator.REGISTRY_NAME, help="The validation process while training."),
         ModuleFlag(PruningSchedule.REGISTRY_NAME, help="The schedule for weight pruning.",
                    default=PolynomialDecay.__name__),
         Flag("tb_log_dir", dtype=Flag.TYPE.STRING, default=None,
              help="The path to store tensorboard summary, or `model_dir`/train by default."),
         Flag("train_steps", dtype=Flag.TYPE.INTEGER, default=10000000,
              help="The maximum steps for training loop."),
         Flag("summary_steps", dtype=Flag.TYPE.INTEGER, default=200,
              help="Doing summary(logging & tensorboard) this every steps."),
         Flag("save_checkpoint_steps", dtype=Flag.TYPE.INTEGER, default=1000,
              help="Saving checkpoints this every steps."),
         Flag("checkpoints_max_to_keep", dtype=Flag.TYPE.INTEGER, default=8,
              help="The maximum checkpoints to be kept."),
         Flag("initial_global_step", dtype=Flag.TYPE.INTEGER, default=None,
              help="The manually specified initial global step."),
         Flag("pretrain_model", dtype=Flag.TYPE.STRING, default=None, multiple=True,
              help="The path to a pretrained model directory(a seq2seq model, bert model, etc.). "
                   "(V2) Or a json/yaml-like dict string indicating pretrained models from "
                   "either neurst checkpoints or publicly available models converted "
                   "by neurst.utils.converters. Each entry has the elements: "
                   "path, model_name, from_prefix, to_prefix, name_pattern. "
                   "Multiple pretrain models are also available."),
         Flag("pretrain_variable_pattern", dtype=Flag.TYPE.STRING, default=None, multiple=True,
              help="One can restore specified variables in the `pretrain_model` by this regular expression."
                   "Multiple pattern are also available, but must match to `pretrain_model`."),
         Flag("update_cycle", dtype=Flag.TYPE.INTEGER, default=1,
              help="Training step with this many batches (Gradient Accumulation)."),
         Flag("clip_value", dtype=Flag.TYPE.FLOAT, default=None, help="Gradient clipping by value."),
         Flag("clip_norm", dtype=Flag.TYPE.FLOAT, default=None, help="Gradient clipping by norm."),
         Flag("experimental_count_batch_num", dtype=Flag.TYPE.BOOLEAN, default=None,
              help="Pre-scan the dataset for training and count the number of batches."),
         Flag("freeze_variables", dtype=Flag.TYPE.STRING, default=None,
              help="Variables whose names are matched with this regex will be freezed."),
         Flag("pruning_variable_pattern", dtype=Flag.TYPE.STRING, default=None,
              help="The regular expression that indicates the variables will be pruned."),
         Flag("nopruning_variable_pattern", dtype=Flag.TYPE.STRING, default=None,
              help="The regular expression that indicates the variables will NOT be pruned "
                   "(will take effect if `pruning_variable_pattern`=None)."),
     ]
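
To make the (V2) `pretrain_model` format concrete, here is an illustrative value built strictly from the elements named in the help text (path, model_name, from_prefix, to_prefix, name_pattern); all concrete values are hypothetical placeholders:

 # Illustrative only: one entry of the json/yaml-like dict described above.
 pretrain_model = [{
     "path": "/path/to/pretrained_ckpt",  # hypothetical checkpoint directory
     "model_name": "bert",                # illustrative model name
     "from_prefix": "bert",               # variable prefix in the checkpoint
     "to_prefix": "encoder",              # prefix to map the variables onto
     "name_pattern": "encoder.*",         # regex selecting variables to load
 }]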
Example #8
 def class_or_method_args():
     return [
         Flag("input_tarball",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="The original tarball."),
         ModuleFlag(FeatureExtractor.REGISTRY_NAME,
                    default=None,
                    help="The audio feature extractor.")
     ]
Example #9
 def class_or_method_args():
     return [
         Flag("path", dtype=Flag.TYPE.STRING,
              help="The path to multilingual datasets. "
                   "The record files should be stored in sub directories, which are named by src2trg, "
                   "e.g. rootpath/en2de, rootpath/en2fr..."),
         Flag("auto_switch_langs", dtype=Flag.TYPE.BOOLEAN,
              help="Whether to switch source and target langauges (which will doubled the dataset)."),
         ModuleFlag(DataSampler.REGISTRY_NAME, default=TemperatureSampler.__name__,
                    help="The sampler for unbalanced datasets.")
     ]
Example #10
 def class_or_method_args():
     this_args = super(CriterionValidator,
                       CriterionValidator).class_or_method_args()
     this_args.extend([
         ModuleFlag("eval_criterion",
                    Criterion.REGISTRY_NAME,
                    help="The criterion for validation."),
         ModuleFlag("eval_dataset",
                    Dataset.REGISTRY_NAME,
                    help="The dataset for validation."),
         Flag("eval_batch_size",
              dtype=Flag.TYPE.INTEGER,
              default=32,
              help="The batch size for validation process."),
         Flag("eval_task_args",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="Other parameters for building validation dataset.")
     ])
     return this_args
Example #11
 def class_or_method_args():
     this_args = super(Seq2Seq, Seq2Seq).class_or_method_args()
     this_args.extend([
         # for creating data pipelines
         ModuleFlag("src_data_pipeline", DataPipeline.REGISTRY_NAME,
                    help="The source side data pipeline."),
         ModuleFlag("trg_data_pipeline", DataPipeline.REGISTRY_NAME,
                    help="The target side data pipeline."),
         # for preprocessing data
         Flag("max_src_len", dtype=Flag.TYPE.INTEGER, default=None,
              help="The maximum source length of training data."),
         Flag("max_trg_len", dtype=Flag.TYPE.INTEGER, default=None,
              help="The maximum target length of training data."),
         Flag("truncate_src", dtype=Flag.TYPE.BOOLEAN, default=None,
              help="Whether to truncate source to max_src_len."),
         Flag("truncate_trg", dtype=Flag.TYPE.BOOLEAN, default=None,
              help="Whether to truncate target to max_trg_len."),
         # for batching dataset
         Flag("batch_by_tokens", dtype=Flag.TYPE.BOOLEAN, default=None,
              help="Whether to batch the data by word tokens."),
     ])
     return this_args
Example #12
 def class_or_method_args():
     return [
         ModuleFlag(Encoder.REGISTRY_NAME, default=None, help="The encoder."),
         ModuleFlag(Decoder.REGISTRY_NAME, default=None, help="The decoder."),
         Flag("modality.share_source_target_embedding", dtype=Flag.TYPE.BOOLEAN, default=False,
              help="Whether to share source and target embedding table."),
         Flag("modality.share_embedding_and_softmax_weights", dtype=Flag.TYPE.BOOLEAN, default=False,
              help="Whether to share the target embedding table and softmax weights."),
         Flag("modality.dim", dtype=Flag.TYPE.INTEGER, default=None,
              help="The default embedding dimension for both source and target side."),
         Flag("modality.source.dim", dtype=Flag.TYPE.INTEGER, default=None,
              help="The source-side embedding dimension, or `modality.dim` if not provided."),
         Flag("modality.target.dim", dtype=Flag.TYPE.INTEGER, default=None,
              help="The target-side embedding dimension, or `modality.dim` if not provided."),
         Flag("modality.timing", dtype=Flag.TYPE.STRING, default=None,
              help="The arbitrary parameters for positional encoding of both source and target side."),
         Flag("modality.source.timing", dtype=Flag.TYPE.STRING, default=None,
              help="The arbitrary parameters for source-side positional encoding, "
                   "or `modality.timing` if not provided."),
         Flag("modality.target.timing", dtype=Flag.TYPE.STRING, default=None,
              help="The arbitrary parameters for target-side positional encoding, "
                   "or `modality.timing` if not provided.")
     ]
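
The `modality.source.*` and `modality.target.*` flags fall back to the shared `modality.*` value when unset. A minimal sketch of that fallback, assuming the parsed flags are available as a plain dict; the helper name is hypothetical:

 def resolve_modality_dim(args, side):
     # Prefer the side-specific dim; fall back to the shared `modality.dim`.
     side_dim = args.get("modality.%s.dim" % side)
     return side_dim if side_dim is not None else args.get("modality.dim")

 # Only the shared dim is set, so both sides resolve to 512.
 args = {"modality.dim": 512, "modality.source.dim": None}
 assert resolve_modality_dim(args, "source") == 512
 assert resolve_modality_dim(args, "target") == 512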
Example #13
 def class_or_method_args():
     return [
         Flag("input_tarball",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="The original tarball."),
         Flag("excluded_file",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="A file containing transcriptions or translations "
              "that would be removed when reading the corpus "
              "(for filtering out testsets)."),
         ModuleFlag(FeatureExtractor.REGISTRY_NAME,
                    default=None,
                    help="The audio feature extractor.")
     ]
Example #14
 def class_or_method_args():
     return [
         Flag(
             "data_files",
             dtype=Flag.TYPE.STRING,
             help="A dict of data files. The key is the dataset name while "
             "the value is a dict containing arguments indicating data files."
         ),
         Flag("data_class",
              dtype=Flag.TYPE.STRING,
              help="The dataset class for the data files."),
         Flag("common_properties",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="Other common properties for building a dataset."),
         ModuleFlag(DataSampler.REGISTRY_NAME,
                    default=None,
                    help="The sampler for unbalanced datasets.")
     ]
Example #15
 def class_or_method_args():
     return [
         Flag(
             "tb_log_dir",
             dtype=Flag.TYPE.STRING,
             default=None,
             help=
             "The path to store tensorboard summary, or `model_dir`/validation by default."
         ),
         Flag("waiting_interval",
              dtype=Flag.TYPE.INTEGER,
              default=120,
              help="The waiting interval between two evaluation steps."),
         Flag("maximum_waiting_time",
              dtype=Flag.TYPE.INTEGER,
              default=3600,
              help="The maximum waiting time(in seconds)."),
         ModuleFlag(Validator.REGISTRY_NAME,
                    help="The validation process during training."),
     ]
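
The two waiting flags imply a polling loop: re-check for new results every `waiting_interval` seconds and give up after `maximum_waiting_time` seconds without progress. A hedged sketch of such a loop; the helper and its callback are hypothetical, not neurst's code:

 import time

 def wait_for_new_checkpoint(get_latest, waiting_interval=120,
                             maximum_waiting_time=3600):
     # Poll `get_latest()` until it returns a new value, sleeping
     # `waiting_interval` seconds between checks; return None once
     # `maximum_waiting_time` seconds pass without anything new.
     waited = 0
     last = None
     while waited <= maximum_waiting_time:
         latest = get_latest()
         if latest is not None and latest != last:
             return latest
         time.sleep(waiting_interval)
         waited += waiting_interval
     return None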
Example #16
 def class_or_method_args():
     return [
         Flag(
             "data_files",
             dtype=Flag.TYPE.STRING,
             help=
             "A dict of parallel data files. The key is the dataset name while "
             "the value is a dict containing `src_file` and `trg_file`."),
         Flag("data_is_processed",
              dtype=Flag.TYPE.BOOLEAN,
              help="Whether the text data is already processed."),
         Flag("src_lang",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="The source language"),
         Flag("trg_lang",
              dtype=Flag.TYPE.STRING,
              default=None,
              help="The target language"),
         ModuleFlag(DataSampler.REGISTRY_NAME,
                    default=None,
                    help="The sampler for unbalanced datasets.")
     ]
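
For clarity, an illustrative `data_files` value matching the help text above; the dataset names and file paths are hypothetical:

 # Keys are dataset names; each value holds `src_file` and `trg_file`.
 data_files = {
     "newstest2014": {"src_file": "/data/newstest2014.en",
                      "trg_file": "/data/newstest2014.de"},
     "newstest2015": {"src_file": "/data/newstest2015.en",
                      "trg_file": "/data/newstest2015.de"},
 }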
Example #17
 def class_or_method_args():
     return [
         ModuleFlag(Criterion.REGISTRY_NAME,
                    help="The criterion for evaluation."),
     ]
Example #18
 def class_or_method_args():
     return [
         ModuleFlag(Criterion.REGISTRY_NAME,
                    help="The criterion for training or evaluation."),
         ModuleFlag(OPTIMIZER_REGISTRY_NAME,
                    help="The optimizer for training."),
         ModuleFlag(LR_SCHEDULE_REGISTRY_NAME,
                    help="The learning schedule for training."),
         ModuleFlag(Validator.REGISTRY_NAME,
                    help="The validation process while training."),
         Flag(
             "tb_log_dir",
             dtype=Flag.TYPE.STRING,
             default=None,
             help=
             "The path to store tensorboard summary, or `model_dir`/train by default."
         ),
         Flag("train_steps",
              dtype=Flag.TYPE.INTEGER,
              default=10000000,
              help="The maximum steps for training loop."),
         Flag(
             "summary_steps",
             dtype=Flag.TYPE.INTEGER,
             default=200,
             help="Doing summary(logging & tensorboard) this every steps."),
         Flag("save_checkpoint_steps",
              dtype=Flag.TYPE.INTEGER,
              default=1000,
              help="Saving checkpoints this every steps."),
         Flag("checkpoints_max_to_keep",
              dtype=Flag.TYPE.INTEGER,
              default=8,
              help="The maximum checkpoints to be kept."),
         Flag("initial_global_step",
              dtype=Flag.TYPE.INTEGER,
              default=None,
              help="The manually specified initial global step."),
         Flag(
             "pretrain_model",
             dtype=Flag.TYPE.STRING,
             default=None,
             multiple=True,
             help=
             "The path to a pretrained model directory(a seq2seq model, bert model, etc.). "
             "Multiple pretrain models are also available."),
         Flag(
             "pretrain_variable_pattern",
             dtype=Flag.TYPE.STRING,
             default=None,
             multiple=True,
             help=
             "One can restore specified variables in the `pretrain_model` by this regular expression."
             "Multiple pattern are also available, but must match to `pretrain_model`."
         ),
         Flag(
             "update_cycle",
             dtype=Flag.TYPE.INTEGER,
             default=1,
             help=
             "Training step with this many batches (Gradient Accumulation)."
         ),
         Flag("clip_value",
              dtype=Flag.TYPE.FLOAT,
              default=None,
              help="Gradient clipping by value."),
         Flag("clip_norm",
              dtype=Flag.TYPE.FLOAT,
              default=None,
              help="Gradient clipping by norm."),
         Flag(
             "experimental_count_batch_num",
             dtype=Flag.TYPE.BOOLEAN,
             default=None,
             help=
             "Pre-scan the dataset for training and count the number of batches."
         )
     ]