forked from keyua-cisco/theano-nlp
/
imdb_dataset.py
40 lines (31 loc) · 1.01 KB
/
imdb_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import cPickle
import gzip
from data_utils import get_file
import random
def load_data(path="data/imdb.pkl", n_words=100000, maxlen=None, test_split=0.2, seed=113):
    """Load the pickled IMDB dataset, shuffle it, and split it into train/test.

    The pickle is expected to contain a pair ``(X, labels)`` of equally long
    lists (presumably each sample in ``X`` is a sequence of integer word
    indices -- confirm against the dataset).

    Args:
        path: Location handed to ``get_file`` together with the download
            origin. (Presumably ``get_file`` returns a local file path,
            fetching the file when it is missing -- confirm in
            ``data_utils``.) A resulting path ending in ``.gz`` is read
            through ``gzip``.
        n_words: Every word index ``>= n_words`` is replaced by ``1``
            (presumably the out-of-vocabulary marker -- confirm).
        maxlen: If truthy, only samples with ``len(x) < maxlen`` are kept.
            ``None`` (or ``0``) disables the filter.
        test_split: Fraction of the (filtered) samples placed in the test
            part; the test part is taken from the end of the shuffled data.
        seed: Seed for ``random``; the same seed is applied before shuffling
            samples and labels so both receive the same permutation.

    Returns:
        ``((X_train, y_train), (X_test, y_test))``.

    Side effects:
        Re-seeds the module-level ``random`` generator.
    """
    path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/imdb.pkl")

    # NOTE(security): unpickling executes arbitrary code embedded in the file.
    # Only use this loader with a trusted origin / trusted local file.
    opener = gzip.open if path.endswith(".gz") else open
    # The context manager guarantees the handle is closed even if unpickling
    # fails or the payload is not a 2-tuple.
    with opener(path, 'rb') as f:
        X, labels = cPickle.load(f)

    # Re-seeding before each shuffle applies the identical permutation to both
    # lists (they must have the same length), keeping samples and labels paired.
    random.seed(seed)
    random.shuffle(X)
    random.seed(seed)
    random.shuffle(labels)

    if maxlen:
        # Strict comparison: a sample of exactly ``maxlen`` items is dropped.
        kept = [(x, y) for x, y in zip(X, labels) if len(x) < maxlen]
        X = [x for x, _ in kept]
        labels = [y for _, y in kept]

    # Cap the vocabulary: indices at or above n_words collapse to index 1.
    X = [[1 if w >= n_words else w for w in x] for x in X]

    # Compute the boundary once so all four slices are guaranteed to agree.
    split_at = int(len(X) * (1 - test_split))
    X_train, y_train = X[:split_at], labels[:split_at]
    X_test, y_test = X[split_at:], labels[split_at:]
    return (X_train, y_train), (X_test, y_test)