-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
58 lines (42 loc) · 1.47 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import sys
import operator
import util
import numpy as np
def shuffle(a, b):
    """Shuffle two equal-length sequences with one shared random permutation.

    Element ``i`` of ``a`` and element ``i`` of ``b`` land at the same
    position in the results, so the pairing between the two is preserved.
    A progress line and the elapsed time (difference of two ``util.now()``
    readings) are printed to stdout.

    Args:
        a: First sequence or array.
        b: Second sequence or array; must have the same length as ``a``.

    Returns:
        A tuple ``(shuffled_a, shuffled_b)`` of new NumPy arrays. The inputs
        themselves are not modified.

    Raises:
        AssertionError: If ``a`` and ``b`` differ in length.
    """
    start_time = util.now()
    print("[Shuffling...]")
    sys.stdout.flush()
    assert len(a) == len(b), "Length of arrays is not equal."
    # Index both inputs with a single permutation rather than stacking them
    # into one array: stacking forces a and b into a common dtype (e.g. ints
    # paired with strings all become strings), fails on empty input, and
    # builds a throwaway Python list of pairs.
    order = np.random.permutation(len(a))
    shuffled_a = np.asarray(a)[order]
    shuffled_b = np.asarray(b)[order]
    print("[Took %d milliseconds]" % (util.now() - start_time))
    return shuffled_a, shuffled_b
def divide(a, b, ratio):
    """Split two equal-length sequences at the same index.

    The cut point is ``round(len(a) * ratio)``; everything before it goes to
    the first part and the rest to the second. The sizes of the two parts and
    the elapsed time (difference of two ``util.now()`` readings) are printed
    to stdout.

    Args:
        a: First sliceable sequence.
        b: Second sliceable sequence; must have the same length as ``a``.
        ratio: Fraction of the elements assigned to the first part.

    Returns:
        A 4-tuple ``(a_first, a_second, b_first, b_second)``.

    Raises:
        AssertionError: If ``a`` and ``b`` differ in length.
    """
    began = util.now()
    total = len(a)
    print("[Dividing len %d data by %f ratio...]" % (total, ratio))
    sys.stdout.flush()
    assert total == len(b), "Length of arrays is not equal."

    cut = round(total * ratio)
    a_head = a[:cut]
    a_tail = a[cut:]
    b_head = b[:cut]
    b_tail = b[cut:]

    print("[Division of %d:%d]" % (len(a_head), len(a_tail)))
    elapsed = util.now() - began
    print("[Took %d milliseconds]" % elapsed)
    return a_head, a_tail, b_head, b_tail
def npz_load(inp, name):
    """Return the entry called ``name`` from the object ``util.loadFile(inp)`` yields.

    The timer is started only after ``util.loadFile`` has returned, so the
    duration printed to stdout covers just the ``[name]`` lookup.

    Args:
        inp: Value handed to ``util.loadFile``.
        name: Key of the entry to fetch from the loaded object.

    Returns:
        Whatever the loaded object holds under ``name``.
    """
    archive = util.loadFile(inp)
    began = util.now()
    print("[Loading %s:%s...]" % (inp, name))
    sys.stdout.flush()
    entry = archive[name]
    elapsed = util.now() - began
    print("[Took %d milliseconds]" % elapsed)
    return entry
def load_pairs(inp, ratio=0.995):
    """Load the "pairs" entry from ``inp``, shuffle it, and split it in two.

    Column 0 of the loaded array is treated as the x values and column 1 as
    the y values. Both are shuffled together with ``shuffle`` and then cut
    at the same index by ``divide``.

    Args:
        inp: Value passed through to ``npz_load``.
        ratio: Fraction of the shuffled pairs placed in the first partition.
            Defaults to 0.995, the value that used to be hard-coded here.

    Returns:
        The 4-tuple produced by ``divide``:
        ``(x_first, x_second, y_first, y_second)``.
    """
    pairs = npz_load(inp, "pairs")
    data_x, data_y = shuffle(pairs[:, 0], pairs[:, 1])
    return divide(data_x, data_y, ratio)
def load_mappings(inp):
    """Load the "index2token" entry from ``inp`` and build its inverse lookup.

    Args:
        inp: Value passed through to ``npz_load``.

    Returns:
        A tuple ``(token2index, index2token)``: ``index2token`` is the loaded
        sequence, and ``token2index`` is a dict mapping each of its elements
        to that element's position. If an element occurs more than once, the
        last position wins.
    """
    index2token = npz_load(inp, "index2token")
    token2index = {
        token: position for position, token in enumerate(index2token)
    }
    return token2index, index2token