def __init__(self, args):
    """Set up the source- and target-side data pipelines, then the base task.

    For each side, the pipeline class is read from ``args`` (falling back to
    ``TextDataPipeline``) together with an optional params dict; a missing
    or ``None`` params entry is treated as no keyword arguments.

    Args:
        args: A dict of model configurations.
    """
    def _make_pipeline(class_key, params_key):
        # Resolve one pipeline's class and kwargs from the config and build it.
        pipeline_class = args.get(class_key, TextDataPipeline)
        pipeline_kwargs = args.get(params_key, None) or {}
        return build_data_pipeline(pipeline_class, **pipeline_kwargs)

    self._src_data_pipeline = _make_pipeline(
        "src_data_pipeline.class", "src_data_pipeline.params")
    self._trg_data_pipeline = _make_pipeline(
        "trg_data_pipeline.class", "trg_data_pipeline.params")
    # Both pipelines are attached before the parent initializer runs;
    # presumably the base class may rely on them — confirm before reordering.
    super(Seq2Seq, self).__init__(args)
def __init__(self, args):
    """Run the base initializer, then attach the transcript and translation pipelines.

    Each pipeline's class is looked up in ``args`` (defaulting to
    ``TranscriptDataPipeline``), and its optional params dict is expanded as
    keyword arguments; a missing or ``None`` params entry means no kwargs.

    Args:
        args: A dict of model configurations.
    """
    super(MultiTaskSpeechTranslation, self).__init__(args)

    def _pipeline(class_key, params_key):
        # Build a single data pipeline from its class/params config entries.
        chosen_cls = args.get(class_key, TranscriptDataPipeline)
        chosen_params = args.get(params_key, None) or {}
        return build_data_pipeline(chosen_cls, **chosen_params)

    self._transcript_data_pipeline = _pipeline(
        "transcript_data_pipeline.class", "transcript_data_pipeline.params")
    # NOTE(review): the translation side also defaults to
    # TranscriptDataPipeline, not a text pipeline — confirm this is intended.
    self._translation_data_pipeline = _pipeline(
        "translation_data_pipeline.class", "translation_data_pipeline.params")
def __init__(self, args):
    """Configure the language-model task's data pipeline and BOS token.

    The pipeline class comes from ``data_pipeline.class`` (default
    ``TextDataPipeline``) and is built with the kwargs in
    ``data_pipeline.params`` (missing or ``None`` means none). The
    begin-of-sentence marker is read from ``begin_of_sentence``, defaulting
    to ``"bos"``. The base initializer is invoked last.

    Args:
        args: A dict of model configurations.
    """
    self._data_pipeline = build_data_pipeline(
        args.get("data_pipeline.class", TextDataPipeline),
        **(args.get("data_pipeline.params", None) or {}))
    self._begin_of_sentence = args.get("begin_of_sentence", "bos")
    # Attributes above are set before the parent initializer runs; keep this
    # ordering in case the base class consults them.
    super(LanguageModel, self).__init__(args)
def __init__(self, args):
    """Initialize the speech-to-text task from a configuration dict.

    After the base initializer runs, this builds the target (transcript)
    data pipeline, stores the required audio feature dimension and channel
    count, and builds the SpecAugment component from the optional
    ``specaug`` entry (``None`` is passed when that key is absent).

    Args:
        args: A dict of model configurations. ``audio_feature_dim`` and
            ``audio_feature_channels`` are mandatory.

    Raises:
        KeyError: If ``audio_feature_dim`` or ``audio_feature_channels`` is
            missing from ``args``.
    """
    super(SpeechToText, self).__init__(args)
    self._trg_data_pipeline = build_data_pipeline(
        args.get("transcript_data_pipeline.class", TranscriptDataPipeline),
        **(args.get("transcript_data_pipeline.params", None) or {}))
    # Required entries: plain indexing deliberately raises if either is absent.
    feature_dim = args["audio_feature_dim"]
    feature_channels = args["audio_feature_channels"]
    self._audio_feature_dim = feature_dim
    self._audio_feature_channels = feature_channels
    self._specaug = SpecAugment.build(args.get("specaug", None))