Example #1
    def adjustWeights(self, input, target, rate, mrate):
        self.setInput(input)
        outnodes = self.layers[2]
        if len(target) != len(outnodes):
            raise ValueError("wrong number of target values")
        deltas = {}
        # compute output deltas (error times activation derivative)
        for k in xrange(len(target)):
            error = target[k] - outnodes[k].outval
            deltas[outnodes[k]] = dfunc(outnodes[k].outval) * error
        
        # Propagate errors back to the hidden layer through the forward weights.
        for node in self.layers[1]:
            error = 0.0
            for conn in node.outputs:
                error += deltas[conn.dest] * conn.weight
            deltas[node] = dfunc(node.outval) * error

        # Update the weights on connections leaving the threshold (bias) node.
        for conn in self.tnode.outputs:
            change = deltas[conn.dest] * conn.weight * rate
            conn.weight += change + self.oldw.get(conn, 0.0) * mrate
            self.oldw[conn] = change
            conn.output = conn.weight

        # compute and apply weight changes (with momentum)
        for layer in self.layers:
            for node in layer:
                for conn in node.outputs:
                    if conn.dest.id != -1 and conn.src.id != -1:  # id == -1 marks the threshold node, handled above
                        change = deltas[conn.dest] * conn.input * rate
                        conn.weight += change + self.oldw.get(conn, 0.0) * mrate
                        self.oldw[conn] = change
        
        # find error (RMS)
        error = 0.0
        for k in xrange(len(target)):
            error += 1.0/len(target) * (target[k] - outnodes[k].outval)**2.0
        return error ** 0.5
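Both examples call a module-level dfunc that is not defined on this page. Since it is always applied to a node's stored output (outval, or a history entry in Example #2) rather than to the net input, it is presumably the derivative of the activation function written in terms of the activation's output. A minimal sketch, assuming a sigmoid activation (the actual function in the source may differ):

import math

def func(x):
    # Sigmoid activation.
    return 1.0 / (1.0 + math.exp(-x))

def dfunc(y):
    # Derivative of the sigmoid in terms of its output y = func(x):
    # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)) = y * (1 - y).
    return y * (1.0 - y)

Writing the derivative in terms of the output is what lets the code pass node.outval (or node.history[i]) straight into dfunc without recomputing the forward activation.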
Example #2
    def adjustWeights(self, target, rate, mrate):
        """Apply backpropagation through time on network and adjust weights

        target (list) :: the list of target outputs
        rate (float) :: learning rate
        mrate (float) :: momentum rate

        I don't think I can explain how backprop through time works here. There are
        papers and books for that. But I can explain my implementation and how it ties
        into the damned algorithm. 90% of this is standard backprop. Understanding that
        is probably a wise prerequisite. I'll be mostly commenting on the "through time"
        parts. Read on!
        """
        outnodes = self.layers[2]
        if len(target) != len(outnodes):
            raise ValueError("wrong number of target values")

        # The keys are a tuple of the node and the time step for that delta (node, time). 
        # time = 0 is for the current time step. time > 0 is for back in time. It would
        # make more sense if time decreased, but it makes the programming a little easier.
        deltas = {}

        # Calculate output errors. Nothing different here.
        for k in xrange(len(target)):
            error = target[k] - outnodes[k].outval
            deltas[(outnodes[k],0)] = dfunc(outnodes[k].outval) * error
        
        # Calculate output error for hidden nodes at t = 0. Recurrent links aren't used for
        # this calculation. Except for that if statement, this is identical to regular
        # backprop. The reason recurrent links aren't used is that the hidden state layer
        # and the hidden past-state layer are really "different" layers. The deltas at t = 0
        # only depend on the output layer. The
        # state layer at t = 1 depends on the state layer at t = 0, t = 2 depends on t = 1, 
        # and so on. I hope this makes sense. BPTT is really confusing and hard to explain
        # without pictures. At least for me.
        for node in self.layers[1]:
            error = 0.0
            for conn in node.outputs:
                if conn.delay == 0:
                    error += deltas[(conn.dest, 0)] * conn.weight
            deltas[(node,0)] = dfunc(node.outval) * error

        # Now we can calculate deltas for older time steps. The t = 0 hidden layer had to
        # be handled separately because it relies on the output layer, whereas each t > 0
        # hidden layer depends on the next-newer hidden layer. From here on we can move
        # through the time steps iteratively. Notice that this is for the most part
        # identical to regular backprop, except that we keep moving back in time (by
        # increasing t, or i in this case), and the dependent layer isn't the output
        # layer but the hidden layer at t - 1.
        
        # I'm assuming a delay of 1 here at all times. I should make it more general.
        for i in xrange(self.history-1):
            for node in self.layers[1]:
                error = 0.0
                for conn in node.outputs:
                    if conn.delay == 1:
                        error += deltas[(conn.dest, i)] * conn.weight
                # I think this is right because history[0] will be the same as node.outval
                # due to the way I store the history.
                deltas[(node, i+1)] = dfunc(node.history[i+1]) * error

        # Now we can start changing weights. First the threshold nodes.
        for conn in self.tnode.outputs:
            change = deltas[(conn.dest,0)] * conn.weight * rate
            conn.weight += change
            conn.output = conn.weight

        # Now weight changes for the rest of the nodes. The weight changes for the recurrent
        # links are summed over each time step.
        for layer in self.layers:
            for node in layer:
                for conn in node.outputs:
                    if conn.dest.id != -1 and conn.src.id != -1:
                        if conn.delay == 0:
                            change = deltas[(conn.dest,0)] * conn.input * rate
                        elif conn.delay == 1:
                            change = 0.0
                            for i in xrange(self.history-1):
                                change += deltas[(conn.dest, i)] * conn.src.history[i+1] * rate
                        conn.weight += change + self.oldw.get(conn, 0.0) * mrate
                        self.oldw[conn] = change
        
        # find error - RMS
        error = 0.0
        N = len(target)
        for k in xrange(len(target)):
            error += 1.0/N * (target[k] - outnodes[k].outval)**2.0
        error = error ** 0.5
        return error
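To make the "through time" bookkeeping concrete, here is a self-contained sketch of the same truncated-BPTT recursion collapsed onto a single unit with one self-recurrent (delay-1) connection. Every name and number here (history, w_rec, the 0.1 error term) is illustrative rather than taken from the source, and dfunc is the sigmoid derivative assumed above:

def dfunc(y):
    # Sigmoid derivative in terms of the output.
    return y * (1.0 - y)

history = [0.9, 0.7, 0.6, 0.5]  # unit activations, newest first (history[0] == outval)
w_rec = 0.4                     # the delay-1 (recurrent) weight
rate, mrate = 0.25, 0.1         # learning and momentum rates
old_change = 0.0                # previous weight change, cf. self.oldw

# The delta at t = 0 comes from the output layer; 0.1 stands in for that error.
deltas = [dfunc(history[0]) * 0.1]

# Push the delta back through time: the delta at step i + 1 depends on the delta
# at step i through the recurrent weight (cf. the conn.delay == 1 loops above).
for i in range(len(history) - 1):
    deltas.append(dfunc(history[i + 1]) * deltas[i] * w_rec)

# The recurrent weight change is summed over the unrolled time steps, pairing
# each delta with the source activation from one step further back.
change = sum(deltas[i] * history[i + 1] * rate for i in range(len(history) - 1))
w_rec += change + old_change * mrate  # momentum, as in the weight loop above
old_change = change

With self.history playing the role of len(history) here, this is the same recursion as the i loop in adjustWeights, just with the layer and connection bookkeeping stripped away.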