Example #1
 def feature_encoders(self, data_dir):
     if self.is_character_level:
         encoder = text_encoder.ByteTextEncoder()
     else:
         vocab_filename = os.path.join(
             data_dir, "vocab.ende.%d" % self.targeted_vocab_size)
         encoder = text_encoder.SubwordTextEncoder(vocab_filename)
     input_encoder = text_encoder.ImageEncoder(channels=self.num_channels)
     return {"inputs": input_encoder, "targets": encoder}
Example #2
def build_subword_tokenizer(vocab_path):
    encoder = text_encoder.SubwordTextEncoder(vocab_path)

    def encode(x):
        ids = encoder.encode(x)
        subtokens = [encoder._subtoken_id_to_subtoken_string(s) for s in ids]
        return subtokens, ids

    return encode
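A minimal usage sketch for the closure above; the vocab path and the sample sentence are assumptions, not part of the original example:

# Hypothetical vocab file path; any existing SubwordTextEncoder vocab file works.
encode = build_subword_tokenizer("vocab.ende.32768")
subtokens, ids = encode("The quick brown fox")
# `subtokens` are the human-readable pieces, `ids` the matching integer ids.
for piece, idx in zip(subtokens, ids):
    print(piece, idx)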
Example #3
 def feature_encoders(self, data_dir):
   source_vocab_filename = os.path.join(data_dir, self.source_vocab_name)
   target_vocab_filename = os.path.join(data_dir, self.target_vocab_name)
   source_token = CharacterTextEncoder(source_vocab_filename, replace_oov="UNK")
   target_token = text_encoder.SubwordTextEncoder(target_vocab_filename)
   return {
     "inputs": source_token,
     "targets": target_token,
   }
Example #4
def get_or_generate_vocab(data_dir,
                          tmp_dir,
                          vocab_filename,
                          vocab_size,
                          sources=None):
    """Generate a vocabulary from the datasets in sources (_DATA_FILE_URLS)."""
    vocab_filepath = os.path.join(data_dir, vocab_filename)
    if tf.gfile.Exists(vocab_filepath):
        tf.logging.info("Found vocab file: %s", vocab_filepath)
        vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
        return vocab

    sources = sources or _DATA_FILE_URLS
    tf.logging.info("Generating vocab from: %s", str(sources))
    token_counts = defaultdict(int)
    for source in sources:
        url = source[0]
        filename = os.path.basename(url)
        read_type = "r:gz" if "tgz" in filename else "r"

        compressed_file = maybe_download(tmp_dir, filename, url)

        with tarfile.open(compressed_file, read_type) as corpus_tar:
            corpus_tar.extractall(tmp_dir)

        for lang_file in source[1]:
            tf.logging.info("Reading file: %s" % lang_file)
            filepath = os.path.join(tmp_dir, lang_file)

            # For some datasets a second extraction is necessary.
            if ".gz" in lang_file:
                new_filepath = os.path.join(tmp_dir, lang_file[:-3])
                if tf.gfile.Exists(new_filepath):
                    tf.logging.info(
                        "Subdirectory %s already exists, skipping unpacking" %
                        filepath)
                else:
                    tf.logging.info("Unpacking subdirectory %s" % filepath)
                    gunzip_file(filepath, new_filepath)
                filepath = new_filepath

            # Use Tokenizer to count the word occurrences.
            with tf.gfile.GFile(filepath, mode="r") as source_file:
                file_byte_budget = 3.5e5 if "en" in filepath else 7e5
                for line in source_file:
                    if file_byte_budget <= 0:
                        break
                    line = line.strip()
                    file_byte_budget -= len(line)
                    for tok in tokenizer.encode(
                            text_encoder.native_to_unicode(line)):
                        token_counts[tok] += 1

    vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
        vocab_size, token_counts, 1, 1e3)
    vocab.store_to_file(vocab_filepath)
    return vocab
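A hedged sketch of how the returned vocabulary is typically used afterwards; the directories, vocab name, and sentence below are placeholders:

# Hypothetical paths and sizes, for illustration only.
vocab = get_or_generate_vocab("/data", "/tmp/t2t", "vocab.ende.32768", 2**15)
ids = vocab.encode("hello world")   # text -> subtoken ids
text = vocab.decode(ids)            # subtoken ids -> text
# SubwordTextEncoder is designed to be invertible, so `text == "hello world"`.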
Example #5
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tokenizer = FLAGS.tokenizer.value()
    subword_tokenizer = text_encoder.SubwordTextEncoder(
        FLAGS.vocabulary_filepath)

    files = glob.glob(FLAGS.input_filepath + '/**/*.py', recursive=True)
    examples = []
    unique_id = 0
    skipped = 0
    for fname in tqdm(files, desc="Extracting methods"):
        try:
            with open(fname) as fh:
                root = ast.parse(fh.read(), fname)
        except Exception as e:
            skipped += 1
            if FLAGS.verbose:
                print(f"Skipping problematic file {e}", fname, file=sys.stderr)
            continue

        # Get label from parent folder of file
        label = fname.split('/')[5]

        # Only consider methods
        for node in ast.iter_child_nodes(root):
            if isinstance(node, ast.FunctionDef):
                method_name = node.name
                method_string = astunparse.unparse(node)
                method_body = astunparse.unparse(node.body)
                examples.append({
                    "unique_id": unique_id,
                    "filepath": fname,
                    "label": label,
                    "method_name": method_name,
                    "method_string": method_string,
                    "method_body": method_body
                })
                unique_id += 1

    print(
        f"DONE WITH EXTRACTION. SKIPPED {skipped} PROBLEMATIC FILES IN TOTAL")

    with open(FLAGS.output_filepath, 'w') as csvfile:
        writer = csv.DictWriter(csvfile,
                                fieldnames=[
                                    'unique_ids', 'filepath', 'label',
                                    'method_name', 'tokens', 'input_ids',
                                    'input_mask', 'input_type_ids'
                                ])
        writer.writeheader()
        for sample in tqdm(
                examples,
                desc=f"Writing results to file {FLAGS.output_filepath}"):
            tokens = tokenize(sample, tokenizer, subword_tokenizer)
            writer.writerow(tokens)
Example #6
 def TokenEncoder(self, vocab_file_path):
     '''
     if self.token_type == 'word_piece_sequence':
         return text_encoder.SubwordTextEncoder(vocab_file_path)
     else:
         return text_encoder.TokenTextEncoder(
             vocab_file_path, replace_oov=OOV_token)
     '''
     return text_encoder.SubwordTextEncoder(vocab_file_path)
Example #7
def _get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size):
    """Read or create vocabulary."""
    vocab_filepath = os.path.join(tmp_dir, vocab_filename)
    print('Vocab file path: ' + vocab_filepath)

    if tf.gfile.Exists(vocab_filepath):
        gs = text_encoder.SubwordTextEncoder(vocab_filepath)
        return gs
    example_file = os.path.join(tmp_dir, _EXAMPLES_FILE)
    gs = text_encoder.SubwordTextEncoder()
    token_counts = tokenizer.corpus_token_counts(example_file,
                                                 corpus_max_lines=1000000)
    gs = gs.build_to_target_size(vocab_size,
                                 token_counts,
                                 min_val=1,
                                 max_val=1e3)
    gs.store_to_file(vocab_filepath)
    return gs
Example #8
def make_or_load_limited_size_subword_tokenizer(
    vocab_path, token_counts_untruncated, truncation_length, target_vocab_size):
  """Makes or loads a tokenizer depending on whether the vocab file exists.

  Also saves the vocab if the file does not exist. This function differs from
  the `make_or_load_subword_tokenizer` function above in three ways:
  (1) Token counts are provided directly as a dict, rather than obtained by
      counting in a sequence of texts.
  (2) The tokens are truncated to a maximum length `truncation_length`, and the
      counts of tokens mapping onto the same truncated token are summed.
  (3) The optional parameter `max_subtoken_length` is passed to the method
      `build_to_target_size` of SubwordTextEncoder, to avoid the quadratic
      in vocabulary size cost (both time and memory).

  Args:
    vocab_path: Full vocab path to save to and/or load from. If None, just
      create a tokenizer without loading or saving.
    token_counts_untruncated: Dict {(untruncated) token: count}.
    truncation_length: Integer specifying maximum token length after truncation.
    target_vocab_size: Target size of the vocabulary.

  Returns:
    A text_encoder.SubwordTextEncoder.
  """
  if vocab_path is not None and tf.gfile.Exists(vocab_path):
    tf.logging.info("Using vocab from '%s'", vocab_path)
    return text_encoder.SubwordTextEncoder(vocab_path)

  # Truncate tokens to maximum length NODE_TEXT_TRUNCATION_LENGTH:
  token_counts = collections.defaultdict(int)
  for token in token_counts_untruncated:
    token_truncated = token[:truncation_length]
    token_counts[token_truncated] += token_counts_untruncated[token]
  tf.logging.info("Vocabulary size reduced from %d to %d due to truncation." % (
      len(token_counts_untruncated), len(token_counts)))

  # The subword tokenizer searches over minimum subtoken counts to create
  # vocab items for, aiming to get a vocab size near `target_vocab_size`.
  # We need to provide initial lower and upper bounds to initialize the binary
  # search.
  # TODO(matejb): This might be tweaked in the future (see corresponding TODO
  # in the `make_or_load_subword_tokenizer` function above).
  token_count_binary_search_lower_bound = 1
  token_count_binary_search_upper_bound = len(token_counts) // 2
  tokenizer = text_encoder.SubwordTextEncoder.build_to_target_size(
      target_vocab_size, token_counts,
      token_count_binary_search_lower_bound,
      token_count_binary_search_upper_bound,
      max_subtoken_length=truncation_length,
      reserved_tokens=text_encoder.RESERVED_TOKENS)

  if vocab_path is not None:
    tf.logging.info("Writing vocab to '%s'", vocab_path)
    tokenizer.store_to_file(vocab_path)

  return tokenizer
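A small sketch of exercising the function above entirely in memory (no vocab file is read or written); the token strings and sizes are invented for illustration:

# Hypothetical counts; keys longer than `truncation_length` are cut and merged.
counts = {
    "a_very_long_identifier_name": 3,
    "a_very_long_identifier_name_2": 2,  # merges with the entry above after truncation
    "short": 10,
}
tok = make_or_load_limited_size_subword_tokenizer(
    vocab_path=None,                     # skip loading and saving
    token_counts_untruncated=counts,
    truncation_length=10,
    target_vocab_size=100)
print(tok.vocab_size)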
Example #9
 def __init__(self, data_dict):
     self.encoder = text_encoder.SubwordTextEncoder()
     self.encoder._load_from_file('./lm_subword_text_encoder2')
     #self.dictionary = SequenceVocabulary()
     self.train = torch.tensor(self.encoder.encode(
         data_dict['train'])).type(torch.int64)
     self.valid = torch.tensor(self.encoder.encode(data_dict['val'])).type(
         torch.int64)
     self.test = torch.tensor(self.encoder.encode(data_dict['test'])).type(
         torch.int64)
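For reference, the private `_load_from_file` call above is what the constructor does internally when given a filename; a minimal sketch of the equivalent public form, assuming the same vocab file:

# Equivalent to the two lines above: the filename can be passed directly.
encoder = text_encoder.SubwordTextEncoder('./lm_subword_text_encoder2')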
Example #10
def wmt_ende_v2(model_hparams, vocab_size):
  """English to German translation benchmark with separate vocabularies."""
  p = default_problem_hparams()
  # These vocab files must be present within the data directory.
  source_vocab_filename = os.path.join(model_hparams.data_dir,
                                       "wmt_ende_v2.en.vocab.%d" % vocab_size)
  target_vocab_filename = os.path.join(model_hparams.data_dir,
                                       "wmt_ende_v2.de.vocab.%d" % vocab_size)
  p.input_modality = {
      "inputs": modality.SymbolModality(model_hparams, vocab_size)
  }
  p.target_modality = modality.SymbolModality(model_hparams, vocab_size)
  p.vocabulary = {
      "inputs": text_encoder.SubwordTextEncoder(source_vocab_filename),
      "targets": text_encoder.SubwordTextEncoder(target_vocab_filename),
  }
  p.input_space_id = 3
  p.target_space_id = 8
  return p
Example #11
def main(unused_argv):
  gs = text_encoder.SubwordTextEncoder()
  if not FLAGS.corpus_filepattern:
    raise ValueError('Must provide --corpus_filepattern')
  token_counts = text_encoder.SubwordTextEncoder.get_token_counts(
      FLAGS.corpus_filepattern, FLAGS.corpus_max_lines)
  gs.build_from_token_counts(token_counts,
                             FLAGS.min_count,
                             FLAGS.num_iterations)
  gs.store_to_file(FLAGS.output_fn)
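The same build step can be sketched without the flags module, using a toy in-memory count dict; the counts and output path below are assumptions:

from tensor2tensor.data_generators import text_encoder

toy_counts = {"hello": 5, "world": 3, "subword": 2}
gs = text_encoder.SubwordTextEncoder()
gs.build_from_token_counts(toy_counts, min_count=1)  # num_iterations defaults to 4
gs.store_to_file("/tmp/toy.subwords")                # hypothetical output path
print(gs.vocab_size)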
Example #12
 def feature_encoders(self, data_dir):
   subword_encoder = text_encoder.SubwordTextEncoder(
       os.path.join(data_dir, self.vocab_file))
   error_tag_encoder = text_encoder.TokenTextEncoder(
       os.path.join(data_dir, self.error_tag_vocab_file))
   return {
       "inputs": subword_encoder,
       "targets": subword_encoder,
       "targets_error_tag": error_tag_encoder
   }
Example #13
def CleanWordpieceEncoderNumbers(encoder,
                                 add_end_token=False):  # (rb) unused function?
    wordpiece_encoder = text_encoder.SubwordTextEncoder()
    wordpieces = GetCleanedWordpieceList(encoder)
    if add_end_token:
        wordpieces.append('N_')
    wordpieces = ['<pad>', '<EOS>'] + wordpieces
    wordpiece_encoder._init_subtokens_from_list(wordpieces)
    wordpiece_encoder._init_alphabet_from_tokens(wordpieces)
    return wordpiece_encoder
Example #14
def main(_):
    """Convert a file to examples."""
    if FLAGS.subword_text_encoder_filename:
        encoder = text_encoder.SubwordTextEncoder(
            FLAGS.subword_text_encoder_filename)
    elif FLAGS.token_text_encoder_filename:
        encoder = text_encoder.TokenTextEncoder(
            FLAGS.token_text_encoder_filename)
    elif FLAGS.byte_text_encoder:
        encoder = text_encoder.ByteTextEncoder()
    elif FLAGS.byte_pair_encoder:
        encoder = text_encoder.BytePairEncoder(FLAGS.bpe_encoder_file,
                                               FLAGS.bpe_vocab_file)
    else:
        encoder = None
    reader = tf.python_io.tf_record_iterator(FLAGS.input_filename)
    total_sequences = 0
    total_input_tokens = 0
    total_target_tokens = 0
    nonpadding_input_tokens = 0
    nonpadding_target_tokens = 0
    max_input_length = 0
    max_target_length = 0
    for record in reader:
        x = tf.train.Example()
        x.ParseFromString(record)
        inputs = [
            int(i) for i in x.features.feature["inputs"].int64_list.value
        ]
        targets = [
            int(i) for i in x.features.feature["targets"].int64_list.value
        ]
        if FLAGS.print_inputs:
            print("INPUTS:\n" + encoder.decode(inputs) if encoder else inputs)
        if FLAGS.print_targets:
            print("TARGETS:\n" +
                  encoder.decode(targets) if encoder else targets)
        nonpadding_input_tokens += len(inputs) - inputs.count(0)
        nonpadding_target_tokens += len(targets) - targets.count(0)
        total_input_tokens += len(inputs)
        total_target_tokens += len(targets)
        total_sequences += 1
        max_input_length = max(max_input_length, len(inputs))
        max_target_length = max(max_target_length, len(targets))
        if FLAGS.print_all:
            for k, v in six.iteritems(x.features.feature):
                print("%s: %s" % (k, v.int64_list.value))

    print("total_sequences: %d" % total_sequences)
    print("total_input_tokens: %d" % total_input_tokens)
    print("total_target_tokens: %d" % total_target_tokens)
    print("nonpadding_input_tokens: %d" % nonpadding_input_tokens)
    print("nonpadding_target_tokens: %d" % nonpadding_target_tokens)
    print("max_input_length: %d" % max_input_length)
    print("max_target_length: %d" % max_target_length)
Example #15
def wiki_32k(model_hparams):
    """Wikipedia title to article.  32k subtoken vocabulary."""
    p = default_problem_hparams()
    encoder = text_encoder.SubwordTextEncoder(
        os.path.join(model_hparams.data_dir, "wiki_32k.subword_text_encoder"))
    modality_spec = (registry.Modalities.SYMBOL, encoder.vocab_size)
    p.input_modality = {"inputs": modality_spec}
    p.target_modality = modality_spec
    p.vocabulary = {"inputs": encoder, "targets": encoder}
    p.target_space_id = 3
    return p
Example #16
def main(_):
  """Convert a file to examples."""
  subtokenizer = text_encoder.SubwordTextEncoder(FLAGS.vocab_file)
  reader = tf.python_io.tf_record_iterator(FLAGS.in_file)
  for record in reader:
    x = tf.train.Example()
    x.ParseFromString(record)
    inputs = [int(i) for i in x.features.feature["inputs"].int64_list.value]
    targets = [int(i) for i in x.features.feature["targets"].int64_list.value]
    ShowSequence(subtokenizer, inputs, "inputs")
    ShowSequence(subtokenizer, targets, "targets")
Example #17
 def get_vocab():
   """Get vocab for caption text encoder."""
   if data_dir is not None and vocab_filename is not None:
     vocab_filepath = os.path.join(data_dir, vocab_filename)
     if tf.gfile.Exists(vocab_filepath):
       tf.logging.info("Found vocab file: %s", vocab_filepath)
       vocab_symbolizer = text_encoder.SubwordTextEncoder(vocab_filepath)
       return vocab_symbolizer
     else:
       raise ValueError("Vocab file does not exist: %s" % vocab_filepath)
   return None
Example #18
 def feature_encoders(self, data_dir):
     if self.is_character_level:
         encoder = text_encoder.ByteTextEncoder()
     elif self.use_subword_tokenizer:
         vocab_filename = os.path.join(data_dir, self.vocab_file)
         encoder = text_encoder.SubwordTextEncoder(vocab_filename)
     else:
         vocab_filename = os.path.join(data_dir, self.vocab_file)
         encoder = text_encoder.TokenTextEncoder(vocab_filename)
     if self.has_inputs:
         return {"inputs": encoder, "targets": encoder}
     return {"targets": encoder}
Example #19
def wmt_concat(model_hparams, wrong_vocab_size):
    """English to German translation benchmark."""
    p = default_problem_hparams()
    # This vocab file must be present within the data directory.
    vocab_filename = os.path.join(model_hparams.data_dir,
                                  "tokens.vocab.%d" % wrong_vocab_size)
    subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename)
    vocab_size = subtokenizer.vocab_size
    p.input_modality = {}
    p.target_modality = (registry.Modalities.SYMBOL, vocab_size)
    p.vocabulary = {"targets": subtokenizer}
    return p
Example #20
 def feature_encoders(self, data_dir):
     '''Used at inference time to convert inputs and outputs between ids and
     tokens. The returned encoders are stored in self._encoders.
     '''
     vocab_filename = os.path.join(data_dir, self.vocab_file)
     encoder = text_encoder.SubwordTextEncoder(vocab_filename)
     return {
         "inputs":
         encoder,
         "targets":
         text_encoder.ClassLabelEncoder(class_labels_fname=LABEL_FILE),
     }
Example #21
def wsj_parsing_tokens(model_hparams, prefix, wrong_source_vocab_size,
                       wrong_target_vocab_size):
    """English to parse tree translation benchmark.

  Args:
    model_hparams: a tf.contrib.training.HParams
    prefix: name to use as prefix for vocabulary files.
    wrong_source_vocab_size: a number used in the filename indicating the
      approximate vocabulary size.  This is not to be confused with the actual
      vocabulary size.
    wrong_target_vocab_size: a number used in the filename indicating the
      approximate target vocabulary size. This is not to be confused with the
      actual target vocabulary size.
  Returns:
    a tf.contrib.training.HParams
  """
    p = default_problem_hparams()
    # This vocab file must be present within the data directory.
    source_vocab_filename = os.path.join(
        model_hparams.data_dir,
        prefix + "_source.vocab.%d" % wrong_source_vocab_size)
    target_vocab_filename = os.path.join(
        model_hparams.data_dir,
        prefix + "_target.vocab.%d" % wrong_target_vocab_size)
    source_subtokenizer = text_encoder.SubwordTextEncoder(
        source_vocab_filename)
    target_subtokenizer = text_encoder.SubwordTextEncoder(
        target_vocab_filename)
    p.input_modality = {
        "inputs": (registry.Modalities.SYMBOL, source_subtokenizer.vocab_size)
    }
    p.target_modality = (registry.Modalities.SYMBOL,
                         target_subtokenizer.vocab_size)
    p.vocabulary = {
        "inputs": source_subtokenizer,
        "targets": target_subtokenizer,
    }
    p.input_space_id = 3
    p.target_space_id = 15
    return p
Example #22
    def __init__(self, filepath):
        """Create a T2tVocabulary.

    Args:
      filepath: a string
    """
        # Only import tensor2tensor if necessary.
        from tensor2tensor.data_generators import text_encoder  # pylint: disable=g-import-not-at-top
        # minxu
        #from tensor2tensor.data_generators.ops import subword_text_encoder_ops   # pylint: disable=g-import-not-at-top

        self._filepath = filepath
        self._subword_text_encoder = text_encoder.SubwordTextEncoder(filepath)
Example #23
def main(argv):
    tokenizer = FLAGS.tokenizer.value()
    sub_word_tokenizer = text_encoder.SubwordTextEncoder(
        FLAGS.vocabulary_filepath)

    input_dir = Path(FLAGS.input_dir)
    output_dir = Path(FLAGS.output_dir)

    dataset_tokenizer = MMRDatasetTokenizer(input_dir, output_dir, tokenizer,
                                            sub_word_tokenizer)
    with Pool() as pool:
        pool.map(dataset_tokenizer.tokenize_project,
                 dataset_tokenizer.projects_list)
Example #24
def main(_):
    global s_words, t_words, m_words, s_subws, t_subws, m_subws, sents
    vocab_src = text_encoder.SubwordTextEncoder(FLAGS.vocab_src)
    vocab_trg = text_encoder.SubwordTextEncoder(FLAGS.vocab_trg)
    with open(FLAGS.src, encoding="utf-8") as src, open(FLAGS.trg, encoding="utf-8") as trg:
        for s, t in zip(src, trg):
            sents += 1
            s = s.strip()
            t = t.strip()
            s_w, s_s = words_subwords(vocab_src, s)
            t_w, t_s = words_subwords(vocab_trg, t)
            s_words += s_w
            t_words += t_w
            m_words += max(s_w, t_w)
            s_subws += s_s
            t_subws += t_s
            m_subws += max(s_s, t_s)
            if sents % 100000 == 0:
                print_stats()
            if FLAGS.print:
                print("a" * max(s_s, t_s))
    print_stats()
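`words_subwords` and `print_stats` are not shown in this snippet; a plausible sketch of the counting helper, assuming it returns per-line word and subword counts:

def words_subwords(vocab, line):
    # Hypothetical reconstruction: whitespace-separated word count and the
    # number of subword ids produced by the SubwordTextEncoder for the line.
    return len(line.split()), len(vocab.encode(line))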
Example #25
def lm1b_32k(model_hparams):
    """Billion-word language-modeling benchmark, 32k subword vocabulary."""
    p = default_problem_hparams()
    # ratio of dev tokens (including eos) to dev words (including eos)
    # 176884 / 159658 = 1.107893
    p.perplexity_exponent = 1.107893
    p.input_modality = {}
    encoder = text_encoder.SubwordTextEncoder(
        os.path.join(model_hparams.data_dir, "lm1b_32k.subword_text_encoder"))
    p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
    p.vocabulary = {"targets": encoder}
    p.target_space_id = 3
    return p
Example #26
def lm1b_64k(model_hparams):
    """Billion-word language-modeling benchmark, 64k subtoken vocabulary."""
    p = default_problem_hparams()
    p.perplexity_exponent = 1.067068
    p.input_modality = {}
    p.target_modality = (registry.Modalities.SYMBOL, 65536)
    p.vocabulary = {
        "targets":
        text_encoder.SubwordTextEncoder(
            os.path.join(model_hparams.data_dir,
                         "lm1b_64k.subword_text_encoder"))
    }
    p.target_space_id = 3
    return p
Example #27
def lm1b_16k(model_hparams):
  """Billion-word language-modeling benchmark, 16k subtoken vocabulary."""
  p = default_problem_hparams()
  p.perplexity_exponent = 1.184206
  p.input_modality = {}
  p.target_modality = modality.SymbolModality(model_hparams, 16384)
  p.vocabulary = {
      "targets":
          text_encoder.SubwordTextEncoder(
              os.path.join(model_hparams.data_dir,
                           "lm1b_16k.subword_text_encoder"))
  }
  p.target_space_id = 3
  return p
Example #28
    def generator(self, data_dir, tmp_dir, train):
        """Instance of token generator for AAER training set."""

        token_path = os.path.join(const.T2T_DATA_DIR,
                                  const.T2T_AAER_VOLCAB_NAME)

        with tf.gfile.GFile(token_path, mode="a") as f:
            f.write("UNK\n")  # Add UNK to the vocab.
        token_vocab = text_encoder.SubwordTextEncoder(token_path)

        source_path = const.T2T_AAER_SOURCE_PATH  # if train else const.T2T_AAER_SOURCE_PATH + const.T2T_EVAL_POST_FIX
        targets_path = const.T2T_AAER_TARGETS_PATH  # if train else const.T2T_AAER_TARGETS_PATH + const.T2T_EVAL_POST_FIX

        return token_generator(source_path, targets_path, token_vocab, EOS)
Example #29
    def feature_encoders(self, data_dir):
        source_vocab_filename = os.path.join(data_dir, self.source_vocab_name)
        target_vocab_filename = os.path.join(data_dir, self.target_vocab_name)
        source_token = text_encoder.SubwordTextEncoder(source_vocab_filename)
        target_token = text_encoder.SubwordTextEncoder(target_vocab_filename)
        return {
            "inputs": source_token,
            "targets": target_token,
        }


# @registry.register_problem
# class TranslateEnzhWmt8k(TranslateEnzhWmt32k):
#   """Problem spec for WMT En-Zh translation.
#   This is far from being the real WMT17 task - only toyset here
#   """
#
#   @property
#   def approx_vocab_size(self):
#     return 2**13  # 8192
#
#   @property
#   def dataset_splits(self):
#     return [
#         {
#             "split": problem.DatasetSplit.TRAIN,
#             "shards": 10,  # this is a small dataset
#         },
#         {
#             "split": problem.DatasetSplit.EVAL,
#             "shards": 1,
#         }
#     ]
#
#   def get_training_dataset(self, tmp_dir):
#     """Uses only News Commentary Dataset for training."""
#     return _NC_TRAIN_DATASETS
Example #30
def image_mscoco_tokens(model_hparams, vocab_count):
    """COCO image captioning with captions as tokens."""
    p = default_problem_hparams()
    p.input_modality = {"inputs": (registry.Modalities.IMAGE, None)}
    # This vocab file must be present within the data directory.
    vocab_filename = os.path.join(model_hparams.data_dir,
                                  "vocab.endefr.%d" % vocab_count)
    subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename)
    p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size)
    p.vocabulary = {
        "inputs": text_encoder.TextEncoder(),
        "targets": subtokenizer,
    }
    p.batch_size_multiplier = 256
    p.max_expected_batch_size_per_shard = 2