words.py
import os

import numpy

from chainer.dataset import download


def get_ptb_words():
    """Gets the Penn Tree Bank dataset as long word sequences.

    `Penn Tree Bank <https://www.cis.upenn.edu/~treebank/>`_ is originally a
    corpus of English sentences with linguistic structure annotations. This
    function uses a variant distributed at
    `https://github.com/tomsercu/lstm <https://github.com/tomsercu/lstm>`_,
    which omits the annotations and splits the dataset into three parts:
    training, validation, and test.

    This function returns the training, validation, and test sets, each of
    which is represented as a long array of word IDs. All sentences in the
    dataset are concatenated by the End-of-Sentence mark ``<eos>``, which is
    treated as one of the vocabulary words.

    Returns:
        tuple of numpy.ndarray: Int32 vectors of word IDs.

    .. seealso::
       Use :func:`get_ptb_words_vocabulary` to get the mapping between the
       words and word IDs.

    """
    train = _retrieve_ptb_words('train.npz', _train_url)
    valid = _retrieve_ptb_words('valid.npz', _valid_url)
    test = _retrieve_ptb_words('test.npz', _test_url)
    return train, valid, test
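

# A minimal usage sketch (assuming the word list behind ``raw_txt`` below is
# available): each split comes back as a flat int32 array of word IDs that
# can be sliced directly.
#
#     train, valid, test = get_ptb_words()
#     print(train.dtype, train.shape)  # int32, (num_words,)
#     print(train[:10])                # the first ten word IDs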


def get_ptb_words_vocabulary():
    """Gets the Penn Tree Bank word vocabulary.

    Returns:
        dict: Dictionary that maps words to corresponding word IDs. The IDs
        are used in the Penn Tree Bank long sequence datasets.

    .. seealso::
       See :func:`get_ptb_words` for the actual datasets.

    """
    return _retrieve_word_vocabulary()
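

# A small usage sketch: IDs follow first-appearance order in the training
# text, so e.g. the sentence-boundary marker can be looked up as
#
#     vocab = get_ptb_words_vocabulary()
#     eos_id = vocab['<eos>']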


# In this variant, all three splits are read from the same local word list,
# rather than from the per-split files of the upstream tomsercu/lstm
# distribution.
raw_txt = 'txt/words.txt'
_train_url = raw_txt
_valid_url = raw_txt
_test_url = raw_txt


def _retrieve_ptb_words(name, url):
    def creator(path):
        # Map each word of the raw text to its vocabulary ID and cache the
        # resulting sequence as a compressed .npz file.
        vocab = _retrieve_word_vocabulary()
        words = _load_words(url)
        x = numpy.empty(len(words), dtype=numpy.int32)
        for i, word in enumerate(words):
            x[i] = vocab[word]
        numpy.savez_compressed(path, x=x)
        return {'x': x}

    root = download.get_dataset_directory('txt')
    path = os.path.join(root, name)
    # The first call builds the cache via creator(); later calls load it.
    loaded = download.cache_or_load_file(path, creator, numpy.load)
    return loaded['x']


def _retrieve_word_vocabulary():
    def creator(path):
        # Assign IDs in order of first appearance in the training text, and
        # persist the vocabulary as one word per line (line index == ID).
        words = _load_words(_train_url)
        vocab = {}
        index = 0
        with open(path, 'w') as f:
            for word in words:
                if word not in vocab:
                    vocab[word] = index
                    index += 1
                    f.write(word + '\n')
        return vocab

    def loader(path):
        vocab = {}
        with open(path) as f:
            for i, word in enumerate(f):
                vocab[word.strip()] = i
        return vocab

    root = download.get_dataset_directory('txt')
    path = os.path.join(root, 'vocab.txt')
    return download.cache_or_load_file(path, creator, loader)


def _load_words(url):
    path = download.cached_download(url)
    words = []
    with open(path) as words_file:
        for line in words_file:
            if line:
                # Tokenize on whitespace and mark the sentence boundary.
                words += line.strip().split()
                words.append('<eos>')
    return words
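

if __name__ == '__main__':
    # Minimal smoke test: a sketch that assumes the word list referenced by
    # ``raw_txt`` above is reachable through chainer.dataset.download.
    train, valid, test = get_ptb_words()
    vocab = get_ptb_words_vocabulary()
    # Invert the vocabulary to decode ID sequences back into words.
    id_to_word = {i: w for w, i in vocab.items()}
    print('train/valid/test sizes:', len(train), len(valid), len(test))
    print('vocabulary size:', len(vocab))
    print('first tokens:', ' '.join(id_to_word[int(i)] for i in train[:10]))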