コード例 #1
0
class simple_rnn_classifier():
  """LSTM sequence classifier built with PyBrain.

  Consecutive samples of ``ds.all_moves`` are grouped into sequences of
  ``seq_len`` samples, split into a ~25% test / ~75% train partition, and
  used to train a recurrent network with one sigmoid output per class
  (one-hot targets).
  """

  def __init__(self, ds, seq_len=6, hidden_size=3):
    """Build the sequential dataset, the network and its trainer.

    :param ds: source data; must provide ``num_features``, ``all_moves``
        (samples exposing ``get_features()`` and ``class_``) and
        ``get_classes()``.
    :param seq_len: number of consecutive samples per sequence.
    :param hidden_size: number of LSTM cells in the single hidden layer.
    :raises ValueError: if ``seq_len`` is smaller than 1.
    """
    if seq_len < 1:
      raise ValueError("seq_len must be >= 1")
    classes = ds.get_classes()  # fetched once instead of once per sample
    n_classes = len(classes)
    # One output unit per class.  With a single output unit, argmax() in
    # evaluate() is always 0 and every prediction would count as correct.
    self.alldata = SequentialDataSet(ds.num_features, n_classes)
    for count, sample in enumerate(ds.all_moves):
      # Start a new sequence every seq_len samples.  Doing it before adding
      # the sample keeps all sequences the same length and never leaves an
      # empty sequence at the end.
      if count and count % seq_len == 0:
        self.alldata.newSequence()
      target = [0] * n_classes
      target[classes.index(sample.class_)] = 1
      self.alldata.addSample(sample.get_features(), target)

    # The first returned dataset holds ~25% of the data (used for testing).
    self.tstdata, self.trndata = self.alldata.splitWithProportion(0.25)
    # A single LSTM hidden layer with hidden_size cells.
    self.rnn = buildNetwork(self.trndata.indim,
                            hidden_size,
                            self.trndata.outdim,
                            hiddenclass=LSTMLayer,
                            outclass=SigmoidLayer,
                            recurrent=True)
    self.rnn.randomize()
    self.trainer = BackpropTrainer(self.rnn, dataset=self.trndata, momentum=0.1,
                                   verbose=True, weightdecay=0.01)
    # Test accuracy recorded by each evaluate() call, and the best so far.
    self.correct_perc = []
    self.max_ratio = 0.0

  def start_training(self, epochs=200):
    """Train for ``epochs`` epochs, evaluating after each one.

    The test error after every epoch is appended, comma-separated, to
    ./results/rnn_perf.txt (the file is overwritten on each call).
    """
    with open("./results/rnn_perf.txt", "w") as f:
      for i in range(epochs):
        print("training step:  %d" % i)
        self.trainer.trainEpochs(1)
        f.write(str(self.evaluate()) + ",")
        f.flush()  # keep partial results visible during long runs

  def evaluate(self):
    """Classify every test sample and return the error rate (1 - accuracy).

    The network is reset at the start of each test sequence so recurrent
    state does not leak between sequences.  The accuracy is appended to
    ``self.correct_perc``, and the network is saved via ``write_out``
    whenever the accuracy beats ``self.max_ratio``.

    :raises ValueError: if the test split contains no samples.
    """
    print("epoch: %s" % self.trainer.totalepochs)
    correct = 0
    wrong = 0
    for seq_idx in range(self.tstdata.getNumSequences()):
      self.rnn.reset()
      for inp, target in self.tstdata.getSequenceIterator(seq_idx):
        if argmax(self.rnn.activate(inp)) == argmax(target):
          correct += 1
        else:
          wrong += 1

    total = correct + wrong
    if total == 0:
      raise ValueError("test set is empty; cannot evaluate")
    correct_ratio = correct * 1.0 / total
    self.correct_perc.append(correct_ratio)

    print("Wrong Predictions:  %s Ratio =  %s %%" % (wrong, wrong * 100.0 / total))
    print("Correct Predictions:  %s Ratio =  %s %%" % (correct, correct * 100.0 / total))
    if self.max_ratio < correct_ratio:
      print("Found new max, saving network")
      # File-name prefix (sic) kept so existing result files keep matching.
      self.write_out("best_perfrming_")
      self.max_ratio = correct_ratio

    return 1 - correct_ratio

  def write_out(self, name=""):
    """Save the network as PyBrain XML to ./results/<name>rnn.xml."""
    NetworkWriter.writeToFile(self.rnn, "./results/" + name + "rnn.xml")
コード例 #2
0
ファイル: RNN.py プロジェクト: ranBernstein/GaitKinect
def _readTimeStamps(path):
    """Return ``{timeStamp: {}}`` for every non-blank data row of *path*.

    The first line is skipped as a header; the first whitespace-separated
    column of each remaining row is parsed as an integer time stamp.
    """
    allByTime = {}
    with open(path, 'r') as f:
        f.readline()  # header row, not needed here
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank (e.g. trailing) lines
            allByTime[int(fields[0])] = {}
    return allByTime


def _normalizeJoints(allByTime, allByJoint, featureSpaceIndices):
    """Normalize each feature's values in ``allByJoint`` in place.

    For every feature index, the values whose time stamp is also a key of
    ``allByTime`` are collected (in ``allByTime`` iteration order), passed
    through ``angleExtraction.normelizeVector`` and written back.

    :returns: set of time stamps that have a value for at least one feature.
    """
    timeSet = set()
    for inputIndices in featureSpaceIndices:
        jointValues = allByJoint[inputIndices]
        present = [ts for ts in allByTime if ts in jointValues]  # O(1) lookups
        if not present:
            continue
        timeSet.update(present)
        vec = angleExtraction.normelizeVector([jointValues[ts] for ts in present])
        for i, timeStamp in enumerate(present):
            jointValues[timeStamp] = vec[i]
    return timeSet


def _bestInitialization(net, dataset, trials):
    """Try *trials* random initializations and load the best into *net*.

    Each trial randomizes the network and trains it for one epoch with a
    fresh BackpropTrainer.  A new trainer is needed per trial because the
    trainer's gradient descender works on its own copy of the parameters
    taken at construction, so randomizing under an existing trainer has no
    lasting effect.  The parameters of the trial with the lowest epoch
    error are copied back into *net*.

    :returns: the lowest one-epoch training error, or None if trials == 0.
    """
    bestError = None
    bestParams = None
    for _ in range(trials):
        net.randomize()
        err = BackpropTrainer(net, dataset).train()
        if bestError is None or err < bestError:
            bestError = err
            bestParams = net.params.copy()
    if bestParams is not None:
        net.params[:] = bestParams  # in place, so sub-module views stay linked
    return bestError


def evalSensor(sensorIndex, featureSpaceIndices, trainNum=100, maxGap=100):
    """Train and evaluate an LSTM regressor for one sensor.

    Reads the recording named by the module-level ``fileName`` and extracts
    the angle features listed in *featureSpaceIndices* for *sensorIndex*.
    Keeps time stamps that have a full feature vector and a non-zero
    output, then builds a PyBrain sequential dataset; a new sequence starts
    whenever consecutive time stamps are more than *maxGap* apart.  Trains
    a recurrent LSTM network and prints the train/test MSE and a
    sign-agreement grade.

    Relies on module-level ``fileName``, ``featureNum``, ``angleExtraction``
    and ``normalByPercentile`` defined elsewhere in the file.

    :param sensorIndex: sensor passed to ``prepareAnglesFromInput``.
    :param featureSpaceIndices: feature indices to extract; presumably
        ``len(featureSpaceIndices) == featureNum`` -- TODO confirm.
    :param trainNum: number of random initializations tried, and number of
        further training epochs on the best one.
    :param maxGap: time-stamp gap above which a new sequence is started.
    :raises ValueError: if no usable time stamps remain after filtering.
    """
    allByTime = _readTimeStamps(fileName)
    allByJoint = dict((inputIndices, {}) for inputIndices in featureSpaceIndices)
    for inputIndices in featureSpaceIndices:
        allByTime, allByJoint = angleExtraction.prepareAnglesFromInput(
            fileName, inputIndices, sensorIndex, True, allByTime, allByJoint)

    # NOTE(review): the normalized values end up in allByJoint only, while
    # the dataset below reads raw values from allByTime -- confirm which
    # one the network is meant to see.
    timeSet = _normalizeJoints(allByTime, allByJoint, featureSpaceIndices)

    # Time-ordered stamps with a full feature vector (+ 'output') and a
    # non-zero output value.
    time = []
    allOutput = []
    for timeStamp in sorted(timeSet):
        row = allByTime[timeStamp]
        out = row.get('output', 0)
        if out != 0 and len(row) == featureNum + 1:
            time.append(timeStamp)
            allOutput.append(out)
    if not time:
        raise ValueError('no complete, non-zero samples found in %s' % fileName)

    allOutput = normalByPercentile(allOutput)

    hiddenSize = (featureNum + 2) // 2
    net = buildNetwork(featureNum, hiddenSize, 1, hiddenclass=LSTMLayer,
                       outclass=SigmoidLayer, recurrent=True, bias=True)

    ds = SequentialDataSet(featureNum, 1)
    lastTimeStamp = time[0]
    for i, timeStamp in enumerate(time):
        if timeStamp - lastTimeStamp > maxGap:
            ds.newSequence()
        sample = [allByTime[timeStamp][inputIndices]
                  for inputIndices in featureSpaceIndices]
        ds.appendLinked(sample, allOutput[i])
        lastTimeStamp = timeStamp

    tstdata, trndata = ds.splitWithProportion(0.25)
    print(len(ds))
    bestError = _bestInitialization(net, trndata, trainNum)
    print(bestError)

    # Continue training from the best initialization with a fresh trainer
    # whose descender starts from the restored parameters.
    trainer = BackpropTrainer(net, trndata)
    reportEvery = max(1, trainNum // 10)
    for i in range(trainNum):
        res = trainer.train()
        if i % reportEvery == 0:
            print(res)
    print('trndata.evaluateModuleMSE ' + str(trndata.evaluateModuleMSE(net)))
    print('tstdata.evaluateModuleMSE ' + str(tstdata.evaluateModuleMSE(net)))

    # Fraction of samples whose prediction has the same sign as the target.
    # NOTE(review): this grades the *training* split, and the SigmoidLayer
    # output is positive, so this effectively counts targets > 0 -- confirm
    # intent.
    hits = 0.0
    total = 0.0
    for i in range(trndata.getNumSequences()):
        net.reset()  # recurrent state must not carry over between sequences
        for inp, target in trndata.getSequenceIterator(i):
            res = net.activate(inp)
            total += 1
            if res[0] * target[0] > 0:
                hits += 1
    grade = hits / total if total else 0.0
    print('grade ' + str(grade))
    print('total ' + str(total))