Example #1
  def test_bow2text_dataset(self):
    tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
    graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)

    batch_size = 4
    seq_len = 256
    dataset = paired_dataset.Bow2TextDataset(
        tokenizer,
        graph_tokenizer,
        batch_size=batch_size,
        timesteps=seq_len,
        subset='valid',
        subsample_nodes=0.7,
        repeat=False,
        data_dir=WIKIGRAPHS_ROOT)

    num_tokens = 0
    for batch in dataset:
      num_tokens += batch['mask'].sum()
      self.assertEqual(batch['graphs'].shape,
                       (batch_size, graph_tokenizer.vocab_size))

    raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False)
    raw_num_tokens = 0
    n_pairs = 0
    for pair in raw_dataset:
      raw_num_tokens += len(tokenizer.encode(
          pair.text, prepend_bos=True, append_eos=True))
      n_pairs += 1

    # `mask` covers the targets only, and the first (bos) token of each
    # example never appears in the targets, so it is not counted here.
    self.assertEqual(raw_num_tokens, num_tokens + n_pairs)
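The final count check relies only on the convention that targets are the observations shifted by one token, so the leading bos of each pair never shows up under the target mask. A minimal self-contained sketch of that bookkeeping (the shift-by-one convention, names and toy sequences are illustrative, not taken from the library):

import numpy as np

def count_target_tokens(encoded_pairs):
  """Counts target tokens, i.e. every token except the leading bos."""
  total = 0
  for encoded in encoded_pairs:
    target = encoded[1:]              # shift-by-one: bos is never a target
    mask = np.ones_like(target)       # every target position is a real token
    total += mask.sum()
  return total

# Two toy encoded sequences of the form [bos, ..., eos].
toy_pairs = [np.array([0, 5, 6, 7, 1]), np.array([0, 8, 9, 1])]
raw_num_tokens = sum(len(p) for p in toy_pairs)        # 9
num_tokens = count_target_tokens(toy_pairs)            # 7
assert raw_num_tokens == num_tokens + len(toy_pairs)   # 9 == 7 + 2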
Example #2
    def test_wikitext_dataset_size(self):
        tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
        batch_size = 4
        timesteps = 256
        valid_set = wikitext.WikitextDataset(tokenizer=tokenizer,
                                             batch_size=batch_size,
                                             timesteps=timesteps,
                                             subset='valid',
                                             shuffle_data=False,
                                             repeat=False,
                                             data_dir=WIKITEXT_ROOT)
        n_tokens = 0
        n_bos = 0
        for batch in valid_set:
            n_tokens += (batch['obs'] != tokenizer.pad_token()).sum()
            n_bos += (batch['obs'] == tokenizer.bos_token()).sum()
            self.assertEqual(batch['obs'].shape, (batch_size, timesteps))
            self.assertEqual(batch['target'].shape, (batch_size, timesteps))
            self.assertEqual(batch['should_reset'].shape,
                             (batch_size, timesteps))
            self.assertEqual(batch['mask'].shape, (batch_size, timesteps))

        n_tokens -= n_bos
        self.assertEqual(n_tokens, 217646)
        self.assertEqual(n_bos, 60)
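The `obs`, `target` and `mask` fields are shaped for next-token training. As a hedged illustration of how such a batch is typically consumed (this is not the repository's training code; `masked_nll` and the random `log_probs` below are stand-ins):

import numpy as np

def masked_nll(log_probs, target, mask):
  """Mean negative log-likelihood over unmasked target positions.

  Args:
    log_probs: [batch_size, timesteps, vocab_size] log-probabilities.
    target: [batch_size, timesteps] integer token ids.
    mask: [batch_size, timesteps] 1.0 for real targets, 0.0 for padding.
  """
  nll = -np.take_along_axis(log_probs, target[..., None], axis=-1)[..., 0]
  return (nll * mask).sum() / mask.sum()

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 256, 8))
log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
target = rng.integers(0, 8, size=(4, 256))
mask = np.ones((4, 256))
print(masked_nll(log_probs, target, mask))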
Example #3
  def test_tokenizer(self):
    tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
    # Vocab size must match the published WikiText-103 number plus the two
    # special tokens.
    self.assertEqual(tokenizer.vocab_size, 267735 + 2)

    s = 'Hello world ! \n How are you ?'
    encoded = tokenizer.encode(s, prepend_bos=True)
    self.assertEqual(encoded.shape, (9,))
    decoded = tokenizer.decode(encoded)
    self.assertEqual(s, decoded)
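The `+ 2` corresponds to the special tokens added on top of the 267,735-word WikiText-103 vocabulary (the bos and pad tokens used in Example #2), and the shape (9,) corresponds to the eight whitespace-delimited tokens of `s`, '\n' included, plus the prepended bos. A toy whitespace tokenizer, purely to illustrate the round trip (it skips the '\n' handling for brevity; the real WordTokenizer is driven by WIKITEXT_VOCAB_FILE):

import numpy as np

class ToyWordTokenizer:
  """Whitespace tokenizer with bos/pad special tokens (illustration only)."""

  def __init__(self, words):
    self._id_to_word = ['<bos>', '<pad>'] + sorted(set(words))
    self._word_to_id = {w: i for i, w in enumerate(self._id_to_word)}

  @property
  def vocab_size(self):
    return len(self._id_to_word)

  def encode(self, s, prepend_bos=False):
    ids = [self._word_to_id[w] for w in s.split(' ')]
    return np.array(([0] if prepend_bos else []) + ids)

  def decode(self, ids):
    return ' '.join(self._id_to_word[i] for i in ids if i > 1)  # drop bos/pad

s = 'Hello world ! How are you ?'
toy = ToyWordTokenizer(s.split(' '))
assert toy.vocab_size == 7 + 2               # 7 distinct words + bos + pad
encoded = toy.encode(s, prepend_bos=True)
assert encoded.shape == (8,)                 # 7 words + bos
assert toy.decode(encoded) == s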
Example #4
  def test_graph2text_dataset(self):
    tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
    graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)

    batch_size = 4
    seq_len = 256
    dataset = paired_dataset.Graph2TextDataset(
        tokenizer,
        graph_tokenizer,
        batch_size=batch_size,
        timesteps=seq_len,
        subsample_nodes=0.8,
        subset='valid',
        data_dir=WIKIGRAPHS_ROOT)
    data_iter = iter(dataset)
    batch = next(data_iter)
    self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
    self.assertEqual(batch['target'].shape, (batch_size, seq_len))
    self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
    self.assertEqual(batch['mask'].shape, (batch_size, seq_len))
    self.assertIsInstance(batch['graphs'], list)
    self.assertLen(batch['graphs'], batch_size)
    for i in range(batch_size):
      self.assertIsInstance(batch['graphs'][i], jraph.GraphsTuple)

      # +1 for the center_node mask
      self.assertEqual(
          batch['graphs'][i].nodes.shape[-1], graph_tokenizer.vocab_size + 1)
      self.assertEqual(
          batch['graphs'][i].edges.shape[-1], graph_tokenizer.vocab_size)
      n_edges = batch['graphs'][i].n_edge
      self.assertEqual(batch['graphs'][i].senders.shape, (n_edges,))
      self.assertEqual(batch['graphs'][i].receivers.shape, (n_edges,))

    # Make sure the token count matches between the tokenized dataset and the
    # raw dataset.
    num_tokens = 0
    for batch in dataset:
      num_tokens += batch['mask'].sum()

    raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False)
    raw_num_tokens = 0
    n_pairs = 0
    for pair in raw_dataset:
      raw_num_tokens += len(tokenizer.encode(
          pair.text, prepend_bos=True, append_eos=True))
      n_pairs += 1

    # `mask` covers the targets only, and the first (bos) token of each
    # example never appears in the targets, so it is not counted here.
    self.assertEqual(raw_num_tokens, num_tokens + n_pairs)
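The per-graph assertions above describe the feature layout: node features are graph-vocabulary-sized vectors with one extra column marking the center node, and edge features are graph-vocabulary-sized. A toy jraph.GraphsTuple with that layout (all sizes and values are made up for illustration):

import jraph
import numpy as np

toy_vocab_size = 5
n_node, n_edge = 3, 2
nodes = np.zeros((n_node, toy_vocab_size + 1), dtype=np.float32)
nodes[0, -1] = 1.0                         # extra column flags the center node
edges = np.zeros((n_edge, toy_vocab_size), dtype=np.float32)
toy_graph = jraph.GraphsTuple(
    nodes=nodes,
    edges=edges,
    senders=np.array([0, 1]),
    receivers=np.array([1, 2]),
    globals=None,
    n_node=np.array([n_node]),
    n_edge=np.array([n_edge]),
)
assert toy_graph.nodes.shape[-1] == toy_vocab_size + 1
assert toy_graph.edges.shape[-1] == toy_vocab_size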
Example #5
  def test_text_only_dataset(self):
    tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)

    batch_size = 4
    seq_len = 256
    dataset = paired_dataset.TextOnlyDataset(
        tokenizer,
        batch_size=batch_size,
        timesteps=seq_len,
        subset='valid',
        data_dir=WIKIGRAPHS_ROOT)
    data_iter = iter(dataset)
    batch = next(data_iter)
    faux_batch = dataset.return_faux_batch()

    self.assertCountEqual(list(batch.keys()),
                          ['obs', 'target', 'should_reset', 'mask'])
    self.assertCountEqual(list(faux_batch.keys()),
                          ['obs', 'target', 'should_reset', 'mask'])
    for k, v in batch.items():
      faux_v = faux_batch[k]
      self.assertEqual(v.shape, faux_v.shape)
      self.assertEqual(v.dtype, faux_v.dtype)

    self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
    self.assertEqual(batch['target'].shape, (batch_size, seq_len))
    self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
    self.assertEqual(batch['mask'].shape, (batch_size, seq_len))

    num_tokens = 0
    for batch in dataset:
      num_tokens += batch['mask'].sum()

    raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False)
    raw_num_tokens = 0
    n_pairs = 0
    for pair in raw_dataset:
      raw_num_tokens += len(tokenizer.encode(
          pair.text, prepend_bos=True, append_eos=True))
      n_pairs += 1
    self.assertEqual(num_tokens + n_pairs, raw_num_tokens)
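return_faux_batch yields a placeholder batch with the same keys, shapes and dtypes as a real one, which is convenient for exercising batch-consuming code before running the data pipeline. A self-contained sketch of the same contract (the dtypes here are assumptions; the test only checks that real and faux batches agree):

import numpy as np

def make_faux_batch(batch_size, timesteps):
  """Placeholder batch mirroring the structure of a real text-only batch."""
  return {
      'obs': np.zeros((batch_size, timesteps), dtype=np.int32),
      'target': np.zeros((batch_size, timesteps), dtype=np.int32),
      'should_reset': np.zeros((batch_size, timesteps), dtype=np.float32),
      'mask': np.zeros((batch_size, timesteps), dtype=np.float32),
  }

faux = make_faux_batch(4, 256)
assert sorted(faux) == ['mask', 'obs', 'should_reset', 'target']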
Example #6
  def test_graph_retrieval_dataset(self):
    tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
    graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)

    batch_size = 4
    seq_len = 256
    dataset = paired_dataset.Graph2TextDataset(
        tokenizer,
        graph_tokenizer,
        batch_size=batch_size,
        timesteps=seq_len,
        subsample_nodes=0.8,
        graph_retrieval_dataset=True,
        subset='valid',
        data_dir=WIKIGRAPHS_ROOT)
    data_iter = iter(dataset)
    batch = next(data_iter)

    self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
    self.assertEqual(batch['target'].shape, (batch_size, seq_len))
    self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
    self.assertEqual(batch['mask'].shape, (batch_size, seq_len))
    self.assertEqual(batch['graph_id'].shape, (batch_size,))
    self.assertEqual(batch['seq_id'].shape, (batch_size,))
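In retrieval mode each example additionally carries a graph_id and a seq_id, so graph-text retrieval can be scored by checking whether the best-scoring candidate graph is the one paired with each text. A toy recall@1 computation (the similarity scores are made up; producing them is the model's job and out of scope here):

import numpy as np

# Rows: text sequences; columns: candidate graphs.
scores = np.array([[0.9, 0.1, 0.2],
                   [0.3, 0.8, 0.1],
                   [0.2, 0.4, 0.7]])
candidate_graph_ids = np.array([10, 11, 12])   # ids of the candidate graphs
true_graph_ids = np.array([10, 11, 12])        # batch['graph_id'] per text
recall_at_1 = np.mean(
    candidate_graph_ids[scores.argmax(axis=1)] == true_graph_ids)
assert recall_at_1 == 1.0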
Example #7
def init_tokenizer(dataset_name):
    """Initialie the tokenizer."""
    logging.info('Loading tokenizer...')
    tokenizer = tokenizers.WordTokenizer(VOCAB_FILES_MAP[dataset_name])
    logging.info('Vocab size: %d', tokenizer.vocab_size)
    return tokenizer
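VOCAB_FILES_MAP is defined elsewhere in the source. A usage sketch, purely to show how init_tokenizer is meant to be called (the key and path below are placeholders, not the repository's actual values):

VOCAB_FILES_MAP = {
    'wikitext': '/path/to/wikitext-vocab.csv',   # placeholder path
}

tokenizer = init_tokenizer('wikitext')
print(tokenizer.vocab_size)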