Example #1
0
class CommentNetwork:
    """Character-level recurrent network trained on prompt/response comment pairs.

    Each character is encoded as its BITS_PER_CHAR low-order bits (least
    significant bit first).  The prompt is fed in one character per step;
    the response is then read out by feeding an all-zero "null byte" at
    every step, and an all-zero output marks the end of the response.
    """

    # Number of bits used to encode one character (7 covers the ASCII range).
    BITS_PER_CHAR = 7

    def __init__(self, n, saveFile, opt, lossFunc, mod=None, use_gpu=False, numDirectIterations=1, defaultOutputTruncation=10):
        """Build (or adopt) the model and attach the optimizer to it.

        n -- size of the hidden state.
        saveFile -- path the model is pickled to by saveModel().
        opt -- optimizer object; setup(model) is called on it here.
        lossFunc -- callable invoked as lossFunc(output, target) in forward().
        mod -- existing model to reuse; a fresh FunctionSet is built when None.
        use_gpu -- move the model and all arrays to the GPU when true.
        numDirectIterations -- maximum number of consecutive forward() calls
            trainTree() makes on a single prompt/response pair.
        defaultOutputTruncation -- default truncateSize used by forward().
        """
        self.n = n
        self.saveFile = saveFile
        self.use_gpu = use_gpu
        self.lossFunc = lossFunc
        self.numDirectIterations = numDirectIterations
        self.defaultOutputTruncation = defaultOutputTruncation

        if mod is None:
            # construct network model
            self.model = FunctionSet(
                x_to_h=F.Linear(self.BITS_PER_CHAR, n),
                h_to_h=F.Linear(n, n),
                h_to_y=F.Linear(n, self.BITS_PER_CHAR)
            )
        else:
            self.model = mod

        if self.use_gpu:
            self.model.to_gpu()
        else:
            self.model.to_cpu()

        self.optimizer = opt
        self.optimizer.setup(self.model)

        # constants: the all-zero input used to prompt the network for output,
        # also used as the end-of-sequence training target
        null_byte = np.array([[0] * self.BITS_PER_CHAR], dtype=np.float32)
        if self.use_gpu:
            null_byte = cuda.to_gpu(null_byte)
        self.null_byte = Variable(null_byte)

    def _encode_char(self, c, volatile=False):
        """Return character c as a (1, BITS_PER_CHAR) float32 Variable, LSB first."""
        bits = np.array(
            [[bool(ord(c) & (1 << i)) for i in range(self.BITS_PER_CHAR)]],
            dtype=np.float32)
        if self.use_gpu:
            bits = cuda.to_gpu(bits)
        return Variable(bits, volatile=volatile)

    def _append_decoded(self, yc, y, nullEnd, truncateSize):
        """Decode one network output and append it to the output string y.

        Returns the updated (y, nullEnd, truncateSize).  Once a null byte has
        been seen, nullEnd stays true and nothing further is appended.
        """
        # round every output unit to a bit and assemble the character code
        code = sum(
            bool(round(bit)) * (2 ** i_bit)
            for i_bit, bit in enumerate(cuda.to_cpu(yc.data[0])))
        if not code:  # null byte signifies end of sequence
            nullEnd = True
        if not nullEnd:
            y += chr(code)  # translate to character
            truncateSize -= 1
        return y, nullEnd, truncateSize

    def forward_one_step(self, h, x, computeOutput=True):
        """Advance the hidden state h by one input x.

        Returns (new_h, y) when computeOutput is true, otherwise just new_h.
        """
        h = F.sigmoid(self.model.x_to_h(x) + self.model.h_to_h(h))
        if not computeOutput:
            return h
        return h, F.sigmoid(self.model.h_to_y(h))

    def forward(self, input_string, output_string, truncateSize=None, volatile=False):
        """Run one training step on the pair (input_string, output_string).

        input_string -- prompt, fed in character by character.
        output_string -- expected response used as the training target.
        truncateSize -- maximum number of characters to emit; falls back to
            defaultOutputTruncation when None.
        volatile -- forwarded to every Variable created here.

        Returns (y, nullEnd): y is the string the network produced, nullEnd is
        True if the network terminated the sequence itself and False if the
        output was cut off by truncateSize.
        """
        # NOTE(review): backward() and optimizer.update() below run even when
        # volatile=True -- confirm that this is intended for inference use.
        if truncateSize is None:
            truncateSize = self.defaultOutputTruncation

        # feed variable in, ignoring output until model has whole input string
        h = np.zeros((1, self.n), dtype=np.float32)
        if self.use_gpu:
            h = cuda.to_gpu(h)
        h = Variable(h, volatile=volatile)
        for c in input_string:
            # 7 bits per character, never all 0 for non-NUL ascii
            h = self.forward_one_step(h, self._encode_char(c, volatile), computeOutput=False)

        # prep for training
        self.optimizer.zero_grads()
        y = ''  # output string
        nullEnd = False
        loss = 0

        # Read output by prompting with null bytes; train with training output
        for c in output_string:
            target = self._encode_char(c, volatile)
            h, yc = self.forward_one_step(h, self.null_byte)
            loss += self.lossFunc(yc, target)
            y, nullEnd, truncateSize = self._append_decoded(yc, y, nullEnd, truncateSize)

        # reinforce null byte as end of sequence
        h, yc = self.forward_one_step(h, self.null_byte)
        loss += self.lossFunc(yc, self.null_byte)
        y, nullEnd, truncateSize = self._append_decoded(yc, y, nullEnd, truncateSize)

        # continue reading out as long as network does not terminate and we
        # have not hit truncateSize; these steps do not contribute to the loss
        while not nullEnd and truncateSize > 0:
            h, yc = self.forward_one_step(h, self.null_byte)
            y, nullEnd, truncateSize = self._append_decoded(yc, y, nullEnd, truncateSize)

        # Train
        loss.backward()
        self.optimizer.update()
        return y, nullEnd  # nullEnd True if network terminated output sequence, False if truncated

    def trainTree(self, tree, maxCommentLength=float('inf')):  # DFS training
        """Train depth-first on every (parent body, child body) pair in tree.

        tree -- dict with a 'body' string and an optional 'children' list of
            dicts of the same shape.
        maxCommentLength -- pairs where either side is longer are skipped.

        Pairs where either side is empty or '[deleted]' are skipped as well.
        Each remaining pair is passed to forward() up to numDirectIterations
        times, stopping early once the network reproduces the response, and a
        report is printed after every attempt.

        Returns True only if no attempt anywhere in the tree produced a wrong
        response; a node without children returns True.
        """
        allPass = True
        if 'children' not in tree:
            return allPass

        for child in tree['children']:
            # descend first so deeper pairs are trained before this one;
            # fold the subtree result in so failures below are not lost
            child_ok = self.trainTree(child, maxCommentLength)
            allPass = allPass and child_ok

            prompt = tree['body']
            trainResponse = child['body']
            usable = (
                prompt != '[deleted]' and trainResponse != '[deleted]'
                and prompt and trainResponse
                and len(prompt) <= maxCommentLength
                and len(trainResponse) <= maxCommentLength
            )
            if not usable:
                continue

            for i in range(self.numDirectIterations):
                givenResponse, nullEnd = self.forward(prompt, trainResponse)
                # Single-string print: same bytes under Python 2 and 3.  The
                # space before each '\n<--' keeps the separator the original
                # comma-separated print statement emitted after each repr.
                report = (
                    '<#' + str(i) + '--prompt--' + str(len(prompt)) + 'chars-->\n'
                    + repr(prompt)
                    + ' \n<--trainResponse--' + str(len(trainResponse)) + 'chars-->\n'
                    + repr(trainResponse)
                    + ' \n<--givenResponse--' + str(len(givenResponse)) + 'chars'
                    + ('' if nullEnd else ', truncated') + '-->\n'
                    + repr(givenResponse) + '\n'
                )
                print(report)
                if givenResponse == trainResponse:
                    break
                allPass = False
        return allPass

    # loop over lines in a file identifying if they contain a tree after parsing the json
    def trainFile(self, openFile, maxCommentLength=float('inf')):
        """Train on every JSON tree found in openFile, one tree per line.

        openFile -- open, iterable file object with a `name` attribute.
        maxCommentLength -- forwarded to trainTree().

        Blank lines and JSON objects without a 'children' key are ignored.
        Returns True only if every trained tree passed.
        """
        allPass = True
        for i, treeText in enumerate(openFile):
            treeText = treeText.strip()
            # throw away whitespace
            if not treeText:
                continue
            tree = json.loads(treeText)
            # it's a tree, let's train
            if 'children' in tree:
                print('training #' + str(i) + ' ' + openFile.name)
                allPass &= self.trainTree(tree, maxCommentLength)
        return allPass

    def saveModel(self):
        """Pickle the model to self.saveFile, printing progress messages."""
        print('Stopped computation, saving model. Please wait...')
        # binary mode is what pickle expects; `with` closes the file even if
        # dumping fails part-way
        with open(self.saveFile, 'wb') as f:
            pickle.dump(self.model, f)
        print('Saved model')

    def sig_exit(self, _1, _2):
        """Save the model and terminate the process.

        Both arguments are ignored; presumably this is registered as a signal
        handler (signum, frame) -- confirm against the caller.
        """
        import sys  # local import: only needed on this exit path

        self.saveModel()
        sys.exit()