예제 #1
0
class WmtTranslateEnde(wmt.WmtTranslate):
    """WMT English-German translation dataset.

    Exposes two builder configs over the same T2T train/test/dev data: one
    that does not pass a ``text_encoder_config`` and one that uses an 8k
    subword vocabulary (``SubwordTextEncoder`` with ``vocab_size=2**13``).
    """

    # The configs differ only in their text-encoder settings, so both are
    # stamped out from a single template. An empty kwargs dict leaves
    # ``text_encoder_config`` unpassed, exactly like omitting the argument.
    BUILDER_CONFIGS = [
        wmt.WMTConfig(
            language_pair=("en", "de"),
            version="0.0.1",
            name_suffix="t2t",
            data={
                "train": T2T_ENDE_TRAIN,
                "test": T2T_ENDE_TEST,
                "dev": T2T_ENDE_DEV,
            },
            **encoder_kwargs)
        for encoder_kwargs in (
            {},
            {
                "text_encoder_config": tfds.features.text.TextEncoderConfig(
                    encoder_cls=tfds.features.text.SubwordTextEncoder,
                    name="subwords8k",
                    vocab_size=2**13),  # 2**13 == 8192, hence "8k"
            },
        )
    ]

    @property
    def translate_datasets(self):
        """Return the module-level ``TRANSLATE_DATASETS`` object."""
        return TRANSLATE_DATASETS
예제 #2
0
class WmtTranslateEnfr(wmt.WmtTranslate):
  """English-French WMT translation dataset.

  Exposes four builder configs. For each of the T2T "small" and "large"
  train/dev data sets there is one config that does not pass a
  ``text_encoder_config`` and one that uses an 8k subword vocabulary
  (``SubwordTextEncoder`` with ``vocab_size=2**13``). Only "train" and
  "dev" entries are configured; there is no "test" entry.
  """

  # EN-FR translations (matching the data used by the Tensor2Tensor library).
  # Resulting config order: small/plain, small/subwords8k, large/plain,
  # large/subwords8k.
  BUILDER_CONFIGS = [
      wmt.WMTConfig(
          language_pair=("en", "fr"),
          version="0.0.2",
          name_suffix=suffix,
          data={
              "train": train_data,
              "dev": dev_data,
          },
          **encoder_kwargs)
      for suffix, train_data, dev_data in (
          ("t2t_small", T2T_ENFR_TRAIN_SMALL, T2T_ENFR_DEV_SMALL),
          ("t2t_large", T2T_ENFR_TRAIN_LARGE, T2T_ENFR_DEV_LARGE),
      )
      # This inner iterable is rebuilt for every data-set size, so each
      # subword config receives its own TextEncoderConfig instance. An empty
      # kwargs dict leaves ``text_encoder_config`` unpassed.
      for encoder_kwargs in (
          {},
          {
              "text_encoder_config": tfds.features.text.TextEncoderConfig(
                  encoder_cls=tfds.features.text.SubwordTextEncoder,
                  name="subwords8k",
                  vocab_size=2**13),  # 2**13 == 8192, hence "8k"
          },
      )
  ]

  @property
  def translate_datasets(self):
    """Return the module-level ``TRANSLATE_DATASETS`` object."""
    return TRANSLATE_DATASETS