def test_bow2text_dataset(self):
  tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
  graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)
  batch_size = 4
  seq_len = 256
  dataset = paired_dataset.Bow2TextDataset(
      tokenizer,
      graph_tokenizer,
      batch_size=batch_size,
      timesteps=seq_len,
      subset='valid',
      subsample_nodes=0.7,
      repeat=False,
      data_dir=WIKIGRAPHS_ROOT)

  num_tokens = 0
  for batch in dataset:
    num_tokens += batch['mask'].sum()
    self.assertEqual(batch['graphs'].shape,
                     (batch_size, graph_tokenizer.vocab_size))

  raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False)
  raw_num_tokens = 0
  n_pairs = 0
  for pair in raw_dataset:
    raw_num_tokens += len(tokenizer.encode(
        pair.text, prepend_bos=True, append_eos=True))
    n_pairs += 1

  # The first token of each example is not counted by `mask` as it masks the
  # targets, and the first token of each example never appears in the targets.
  self.assertEqual(raw_num_tokens, num_tokens + n_pairs)
def test_wikitext_dataset_size(self):
  tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
  batch_size = 4
  timesteps = 256
  valid_set = wikitext.WikitextDataset(
      tokenizer=tokenizer,
      batch_size=batch_size,
      timesteps=timesteps,
      subset='valid',
      shuffle_data=False,
      repeat=False,
      data_dir=WIKITEXT_ROOT)

  n_tokens = 0
  n_bos = 0
  for batch in valid_set:
    n_tokens += (batch['obs'] != tokenizer.pad_token()).sum()
    n_bos += (batch['obs'] == tokenizer.bos_token()).sum()
    self.assertEqual(batch['obs'].shape, (batch_size, timesteps))
    self.assertEqual(batch['target'].shape, (batch_size, timesteps))
    self.assertEqual(batch['should_reset'].shape, (batch_size, timesteps))
    self.assertEqual(batch['mask'].shape, (batch_size, timesteps))

  n_tokens -= n_bos
  self.assertEqual(n_tokens, 217646)
  self.assertEqual(n_bos, 60)
def test_tokenizer(self):
  tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
  # Vocab size must match the published number.
  self.assertEqual(tokenizer.vocab_size, 267735 + 2)

  s = 'Hello world ! \n How are you ?'
  encoded = tokenizer.encode(s, prepend_bos=True)
  self.assertEqual(encoded.shape, (9,))
  decoded = tokenizer.decode(encoded)
  self.assertEqual(s, decoded)
def test_graph2text_dataset(self):
  tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
  graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)
  batch_size = 4
  seq_len = 256
  dataset = paired_dataset.Graph2TextDataset(
      tokenizer,
      graph_tokenizer,
      batch_size=batch_size,
      timesteps=seq_len,
      subsample_nodes=0.8,
      subset='valid',
      data_dir=WIKIGRAPHS_ROOT)
  data_iter = iter(dataset)
  batch = next(data_iter)

  self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
  self.assertEqual(batch['target'].shape, (batch_size, seq_len))
  self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
  self.assertEqual(batch['mask'].shape, (batch_size, seq_len))
  self.assertIsInstance(batch['graphs'], list)
  self.assertLen(batch['graphs'], batch_size)
  for i in range(batch_size):
    self.assertIsInstance(batch['graphs'][i], jraph.GraphsTuple)
    # +1 for the center_node mask.
    self.assertEqual(
        batch['graphs'][i].nodes.shape[-1], graph_tokenizer.vocab_size + 1)
    self.assertEqual(
        batch['graphs'][i].edges.shape[-1], graph_tokenizer.vocab_size)
    n_edges = batch['graphs'][i].n_edge
    self.assertEqual(batch['graphs'][i].senders.shape, (n_edges,))
    self.assertEqual(batch['graphs'][i].receivers.shape, (n_edges,))

  # Make sure the token count matches across the tokenized data and the raw
  # data set.
  num_tokens = 0
  for batch in dataset:
    num_tokens += batch['mask'].sum()

  raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False)
  raw_num_tokens = 0
  n_pairs = 0
  for pair in raw_dataset:
    raw_num_tokens += len(tokenizer.encode(
        pair.text, prepend_bos=True, append_eos=True))
    n_pairs += 1

  # The first token of each example is not counted by `mask` as it masks the
  # targets, and the first token of each example never appears in the targets.
  self.assertEqual(raw_num_tokens, num_tokens + n_pairs)
def test_text_only_dataset(self):
  tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
  batch_size = 4
  seq_len = 256
  dataset = paired_dataset.TextOnlyDataset(
      tokenizer,
      batch_size=batch_size,
      timesteps=seq_len,
      subset='valid',
      data_dir=WIKIGRAPHS_ROOT)
  data_iter = iter(dataset)
  batch = next(data_iter)
  faux_batch = dataset.return_faux_batch()

  self.assertCountEqual(list(batch.keys()),
                        ['obs', 'target', 'should_reset', 'mask'])
  self.assertCountEqual(list(faux_batch.keys()),
                        ['obs', 'target', 'should_reset', 'mask'])
  for k, v in batch.items():
    faux_v = faux_batch[k]
    self.assertEqual(v.shape, faux_v.shape)
    self.assertEqual(v.dtype, faux_v.dtype)

  self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
  self.assertEqual(batch['target'].shape, (batch_size, seq_len))
  self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
  self.assertEqual(batch['mask'].shape, (batch_size, seq_len))

  num_tokens = 0
  for batch in dataset:
    num_tokens += batch['mask'].sum()

  raw_dataset = paired_dataset.RawDataset(subset='valid', shuffle_data=False)
  raw_num_tokens = 0
  n_pairs = 0
  for pair in raw_dataset:
    raw_num_tokens += len(tokenizer.encode(
        pair.text, prepend_bos=True, append_eos=True))
    n_pairs += 1
  self.assertEqual(num_tokens + n_pairs, raw_num_tokens)
def test_graph_retrieval_dataset(self):
  tokenizer = tokenizers.WordTokenizer(vocab_file=WIKITEXT_VOCAB_FILE)
  graph_tokenizer = tokenizers.GraphTokenizer(vocab_file=GRAPH_VOCAB_FILE)
  batch_size = 4
  seq_len = 256
  dataset = paired_dataset.Graph2TextDataset(
      tokenizer,
      graph_tokenizer,
      batch_size=batch_size,
      timesteps=seq_len,
      subsample_nodes=0.8,
      graph_retrieval_dataset=True,
      subset='valid',
      data_dir=WIKIGRAPHS_ROOT)
  data_iter = iter(dataset)
  batch = next(data_iter)

  self.assertEqual(batch['obs'].shape, (batch_size, seq_len))
  self.assertEqual(batch['target'].shape, (batch_size, seq_len))
  self.assertEqual(batch['should_reset'].shape, (batch_size, seq_len))
  self.assertEqual(batch['mask'].shape, (batch_size, seq_len))
  self.assertEqual(batch['graph_id'].shape, (batch_size,))
  self.assertEqual(batch['seq_id'].shape, (batch_size,))
def init_tokenizer(dataset_name):
  """Initialize the tokenizer."""
  logging.info('Loading tokenizer...')
  tokenizer = tokenizers.WordTokenizer(VOCAB_FILES_MAP[dataset_name])
  logging.info('Vocab size: %d', tokenizer.vocab_size)
  return tokenizer
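# A minimal usage sketch for `init_tokenizer`, not an additional test. It
# assumes `VOCAB_FILES_MAP` has a 'wikitext' entry and simply reuses the
# `wikitext.WikitextDataset` arguments exercised in `test_wikitext_dataset_size`
# above; the batch size, timesteps, and data directory are illustrative only.
#
#   tokenizer = init_tokenizer('wikitext')
#   dataset = wikitext.WikitextDataset(
#       tokenizer=tokenizer, batch_size=4, timesteps=256,
#       subset='valid', shuffle_data=False, repeat=False,
#       data_dir=WIKITEXT_ROOT)
#   batch = next(iter(dataset))
#   # batch['obs'] and batch['target'] have shape (batch_size, timesteps).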