-
Notifications
You must be signed in to change notification settings - Fork 1
/
trainer.py
82 lines (72 loc) · 2.86 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy
from pybrain.tools.shortcuts import buildNetwork
from pybrain.tools.customxml import NetworkWriter, NetworkReader
from pybrain.datasets import SequentialDataSet
from pybrain.supervised.trainers import RPropMinusTrainer
from pybrain.structure.modules import LSTMLayer
import pickle
import sys
import os
import glob
import wave_gen
def saveNetwork(filename, net):
    """Serialize *net* to *filename* using :mod:`pickle`.

    Args:
        filename: path of the file to create/overwrite.
        net: any picklable object (typically a pybrain network).

    pickle writes bytes, so the file is opened in binary mode (text mode
    fails under Python 3). The ``with`` block guarantees the file is closed
    even if pickling raises.
    """
    with open(filename, 'wb') as fileObject:
        pickle.dump(net, fileObject)
def loadNetwork(filename):
    """Load and return an object previously written with ``saveNetwork``.

    Args:
        filename: path of a pickle file.

    Returns:
        The unpickled object.

    The file is opened in binary mode (pickle data is bytes) and closed
    before returning.
    """
    # SECURITY: unpickling can execute arbitrary code -- only load files
    # produced by this program / from a trusted source.
    with open(filename, 'rb') as fileObject:
        return pickle.load(fileObject)
def _loadTrack(track):
    """Return ``(features, targets)`` for the track whose path stem is *track*.

    Features are the rows of the comma-separated ``<track>_seg.csv``; targets
    are column 1 of every 10th row of the tab-separated ``<track>REF.txt``.
    ``numpy.atleast_2d`` keeps single-row files two-dimensional so the row
    and column indexing still works.
    """
    data = numpy.atleast_2d(numpy.genfromtxt(track + '_seg.csv', delimiter=","))
    labels = numpy.atleast_2d(
        numpy.genfromtxt(track + 'REF.txt', delimiter='\t'))[0::10, 1]
    return data, labels


def _buildDataset(tracks, numFeatures):
    """Build a SequentialDataSet holding one sequence per track stem in *tracks*."""
    ds = SequentialDataSet(numFeatures, 1)
    for track in tracks:
        # load training data
        print("Reading %s..." % track)
        data, labels = _loadTrack(track)
        if len(labels) < len(data):
            # Previously an IndexError part-way through the track; now the
            # unlabelled tail frames are dropped explicitly.
            print("Warning: %s has %d feature rows but only %d labels; "
                  "truncating." % (track, len(data), len(labels)))
        # add the input to the dataset
        print("Adding to dataset...")
        ds.newSequence()
        # zip stops at the shorter of the two, pairing row i with label i.
        for features, label in zip(data, labels):
            ds.addSample(features, (label,))
    return ds


def _trainUntilStable(trainer, maxEpochs, tolerance):
    """Run up to *maxEpochs* epochs, stopping once the error changes by < *tolerance*."""
    previousError = None
    for _ in range(maxEpochs):
        error = trainer.train()
        print("error: " + str(error))
        if previousError is not None and abs(previousError - error) < tolerance:
            break
        previousError = error


def trainNetwork(dirname, numFeatures=5000, numHidden=50, maxEpochs=100,
                 tolerance=0.1):
    """Train an LSTM regression network on every ``train??.wav`` track in *dirname*.

    For each matching ``<name>.wav`` the features are read from
    ``<name>_seg.csv`` and the targets from ``<name>REF.txt`` (column 1 of
    every 10th row); each track becomes one sequence of a pybrain
    ``SequentialDataSet``. The network (``numFeatures`` inputs, one LSTM
    hidden layer, one output, no output bias) is trained with RProp- and
    written via ``NetworkWriter`` to ``os.path.basename(dirname) + 'net'``,
    a path relative to the current working directory.

    Args:
        dirname: directory containing the training tracks.
        numFeatures: width of each feature row / network input size.
        numHidden: number of units in the LSTM hidden layer.
        maxEpochs: upper bound on training epochs.
        tolerance: stop early once successive values returned by
            ``trainer.train()`` differ by less than this.

    Returns:
        The trained network (also saved to disk as described above).

    Raises:
        ValueError: if *dirname* contains no ``train??.wav`` files.
    """
    pattern = os.path.join(dirname, 'train??.wav')
    # glob order is arbitrary; sort so sequences are added reproducibly.
    tracks = [os.path.splitext(t)[0] for t in sorted(glob.glob(pattern))]
    if not tracks:
        raise ValueError("no training tracks match %s" % pattern)

    ds = _buildDataset(tracks, numFeatures)

    # initialize the neural network
    print("Initializing neural network...")
    net = buildNetwork(numFeatures, numHidden, 1,
                       hiddenclass=LSTMLayer, outputbias=False, recurrent=True)

    # train the network on the dataset
    print("Training neural net")
    trainer = RPropMinusTrainer(net, dataset=ds)
    _trainUntilStable(trainer, maxEpochs, tolerance)

    # save the network
    print("Saving neural network...")
    NetworkWriter.writeToFile(net, os.path.basename(dirname) + 'net')
    return net
if __name__ == '__main__':
dirname = os.path.normpath(sys.argv[1])
# wave_reader.extractFeatures(track)
trainNetwork(dirname)
net = NetworkReader.readFrom(os.path.basename(dirname) + 'net')
# predict on some of the training examples
print "Predicting on training set"
data = numpy.genfromtxt(os.path.join(dirname, 'train09_seg.csv'), delimiter=",")
labels = numpy.genfromtxt(os.path.join(dirname, 'train09REF.txt'), delimiter='\t')[0::10,1]
## for i in range(200):
## print net.activate(data[i]), labels[i]
cdata = numpy.array([])
for feature in data:
freq = max(0, net.activate(feature))
sample = wave_gen.saw(freq, 0.1, 44100)
cdata = numpy.concatenate([cdata, sample])
wave_gen.saveAudioBuffer('test.wav', cdata)
## for freq in labels:
## sample = wave_gen.saw(freq, 0.1, 44100)
## cdata = numpy.concatenate([cdata, sample])
## wave_gen.saveAudioBuffer('test_ref.wav', cdata)