class CommentNetwork:
    """Character-level recurrent network trained to reply to comments.

    Each character is encoded as a 7-element bit vector (see ``_encode_char``).
    The network first reads a prompt one character at a time. It is then
    stepped with null-byte (all-zero) inputs, and its 7 sigmoid outputs are
    decoded back into characters. An all-zero decoded output marks the end
    of the reply.
    """

    def __init__(self, n, saveFile, opt, lossFunc, mod=None, use_gpu=False,
                 numDirectIterations=1, defaultOutputTruncation=10):
        """Build (or adopt) the model and attach the optimizer.

        Args:
            n: size of the hidden state vector.
            saveFile: path that ``saveModel`` pickles the model to.
            opt: optimizer object; ``opt.setup(model)`` is called here.
            lossFunc: callable ``lossFunc(prediction, target)`` whose results
                are summed and back-propagated in ``forward``.
            mod: existing model to reuse. When None, a new FunctionSet is
                built with linear layers ``x_to_h`` (7 -> n), ``h_to_h``
                (n -> n) and ``h_to_y`` (n -> 7).
            use_gpu: move the model and all input arrays to the GPU.
            numDirectIterations: maximum number of ``forward`` passes per
                (prompt, response) pair in ``trainTree``.
            defaultOutputTruncation: ``forward``'s default ``truncateSize``.
        """
        self.n = n
        self.saveFile = saveFile
        self.use_gpu = use_gpu
        self.lossFunc = lossFunc
        self.numDirectIterations = numDirectIterations
        self.defaultOutputTruncation = defaultOutputTruncation

        if mod is None:
            self.model = FunctionSet(
                x_to_h=F.Linear(7, n),
                h_to_h=F.Linear(n, n),
                h_to_y=F.Linear(n, 7),
            )
        else:
            self.model = mod
        if self.use_gpu:
            self.model.to_gpu()
        else:
            self.model.to_cpu()

        self.optimizer = opt
        self.optimizer.setup(self.model)

        # All-zero 7-bit vector. It is the input used to prompt every output
        # step, and also the training target that marks end of sequence.
        self.null_byte = Variable(self._to_device(np.zeros((1, 7), dtype=np.float32)))

    @staticmethod
    def _encode_char(c):
        """Return a (1, 7) float32 array of the low 7 bits of ``ord(c)``, LSB first.

        Only 7 bits are kept, so characters above 127 lose their high bits.
        NUL, and any character whose low 7 bits are all zero, encodes to the
        same vector as the end-of-sequence null byte.
        """
        code = ord(c)
        return np.array([[float((code >> i) & 1) for i in range(7)]], dtype=np.float32)

    @staticmethod
    def _decode_bits(values):
        """Turn per-bit activations (LSB first) into an int, treating >= 0.5 as a set bit."""
        return sum(1 << i for i, bit in enumerate(values) if bit >= 0.5)

    def _to_device(self, array):
        """Return ``array`` copied to the GPU when ``use_gpu`` is set, else unchanged."""
        return cuda.to_gpu(array) if self.use_gpu else array

    def _char_variable(self, c, volatile):
        """Wrap the bit encoding of character ``c`` in a Variable on the right device."""
        return Variable(self._to_device(self._encode_char(c)), volatile=volatile)

    def forward_one_step(self, h, x, computeOutput=True):
        """Advance the recurrence by one step.

        The new state is ``sigmoid(x_to_h(x) + h_to_h(h))``. When
        ``computeOutput`` is true, returns ``(h, y)`` with
        ``y = sigmoid(h_to_y(h))``; otherwise returns only ``h``.
        """
        h = F.sigmoid(self.model.x_to_h(x) + self.model.h_to_h(h))
        if not computeOutput:
            return h
        y = F.sigmoid(self.model.h_to_y(h))
        return h, y

    def forward(self, input_string, output_string, truncateSize=None, volatile=False):
        """Feed a prompt, train on the expected reply, and return the network's reply.

        Training runs for ``len(output_string)`` null-prompted steps against the
        characters of ``output_string``, plus one extra step against the null
        byte. After that, output keeps being read (without training) until the
        network emits a null byte or ``truncateSize`` is exhausted. The
        optimizer is updated once at the end.

        Args:
            input_string: prompt text, read one character at a time.
            output_string: expected reply, used as per-step training targets.
            truncateSize: decoding budget. Every decoded step (trained or not)
                consumes one unit. Defaults to ``defaultOutputTruncation``.
            volatile: forwarded to every input/target Variable created here.

        Returns:
            ``(reply, nullEnd)``. ``reply`` is the decoded output text.
            ``nullEnd`` is True if the network terminated the reply with a
            null byte, False if it was cut off by ``truncateSize``.
        """
        if truncateSize is None:
            truncateSize = self.defaultOutputTruncation

        # Read the whole prompt; outputs are ignored until it has been consumed.
        h = Variable(self._to_device(np.zeros((1, self.n), dtype=np.float32)),
                     volatile=volatile)
        for c in input_string:
            h = self.forward_one_step(h, self._char_variable(c, volatile),
                                      computeOutput=False)

        self.optimizer.zero_grads()
        reply = []
        nullEnd = False
        loss = 0

        def read_char(yc, nullEnd, truncateSize):
            """Decode one output step and append it unless the reply has already ended."""
            code = self._decode_bits(cuda.to_cpu(yc.data[0]))
            if not code:  # a null byte signifies end of sequence
                nullEnd = True
            if not nullEnd:
                reply.append(chr(code))
            return nullEnd, truncateSize - 1

        # Read output by prompting with null bytes; train each step against
        # the corresponding character of the expected reply.
        for c in output_string:
            h, yc = self.forward_one_step(h, self.null_byte)
            loss += self.lossFunc(yc, self._char_variable(c, volatile))
            nullEnd, truncateSize = read_char(yc, nullEnd, truncateSize)

        # Reinforce the null byte as the end-of-sequence marker.
        h, yc = self.forward_one_step(h, self.null_byte)
        loss += self.lossFunc(yc, self.null_byte)
        nullEnd, truncateSize = read_char(yc, nullEnd, truncateSize)

        # Keep reading (untrained) until the network terminates or the budget runs out.
        while not nullEnd and truncateSize > 0:
            h, yc = self.forward_one_step(h, self.null_byte)
            nullEnd, truncateSize = read_char(yc, nullEnd, truncateSize)

        loss.backward()
        self.optimizer.update()
        return ''.join(reply), nullEnd

    @staticmethod
    def _is_trainable_pair(prompt, response, maxCommentLength):
        """True when both comments are non-empty, not '[deleted]', and no longer than the limit."""
        return all(text and text != '[deleted]' and len(text) <= maxCommentLength
                   for text in (prompt, response))

    def trainTree(self, tree, maxCommentLength=float('inf')):
        """Depth-first training over a comment tree.

        Each child subtree is trained first. Then the pair (parent ``body``,
        child ``body``) is trained for up to ``numDirectIterations`` passes,
        stopping early once the network reproduces the child's body.

        Returns:
            True only if every trained pair in the whole subtree reproduced
            its response on every attempted pass. A tree with no
            ``'children'`` key (a leaf) returns True.
        """
        allPass = True
        for child in tree.get('children', ()):
            # Propagate results from deeper levels instead of discarding them.
            childPass = self.trainTree(child, maxCommentLength)
            allPass = allPass and childPass

            prompt = tree['body']
            trainResponse = child['body']
            if not self._is_trainable_pair(prompt, trainResponse, maxCommentLength):
                continue
            for i in range(self.numDirectIterations):
                givenResponse, nullEnd = self.forward(prompt, trainResponse)
                print('<#' + str(i) + '--prompt--' + str(len(prompt)) + 'chars-->\n'
                      + repr(prompt)
                      + '\n<--trainResponse--' + str(len(trainResponse)) + 'chars-->\n'
                      + repr(trainResponse)
                      + '\n<--givenResponse--' + str(len(givenResponse)) + 'chars'
                      + ('' if nullEnd else ', truncated') + '-->\n'
                      + repr(givenResponse) + '\n')
                if givenResponse == trainResponse:
                    break
                allPass = False
        return allPass

    def trainFile(self, openFile, maxCommentLength=float('inf')):
        """Train on every JSON comment tree in ``openFile`` (one tree per line).

        Blank lines, and top-level objects without a ``'children'`` key, are
        skipped. ``openFile`` must expose a ``name`` attribute, which is used
        in progress messages.

        Returns:
            True if every trained tree reported ``allPass``.
        """
        allPass = True
        for i, treeText in enumerate(openFile):
            treeText = treeText.strip()
            if not treeText:  # throw away whitespace-only lines
                continue
            tree = json.loads(treeText)
            if 'children' in tree:  # it's a tree, let's train
                print('training #' + str(i) + ' ' + openFile.name)
                treePass = self.trainTree(tree, maxCommentLength)
                allPass = allPass and treePass
        return allPass

    def saveModel(self):
        """Pickle ``self.model`` to ``self.saveFile``."""
        print('Stopped computation, saving model. Please wait...')
        # Binary mode because pickle writes bytes; 'with' closes the file even if dump fails.
        with open(self.saveFile, 'wb') as f:
            pickle.dump(self.model, f)
        print('Saved model')

    def sig_exit(self, _1, _2):
        """Save the model, then exit the process.

        The two ignored arguments suggest this is registered with
        ``signal.signal`` as a (signum, frame) handler; confirm at the call site.
        """
        import sys
        self.saveModel()
        # sys.exit instead of the site-provided exit(), which is not always installed.
        sys.exit()