def trainModel(self, dataObj, save, plot):
    """Run ``self.displayPeriod`` optimizer steps, then optionally save and plot.

    Each step pulls a batch with ``dataObj.getData(self.batchSize)`` and feeds
    ``data[0]`` to ``self.input`` and ``data[1]`` to ``self.gt``. While
    ``self.preTrain`` is truthy, ``self.optimizerPre`` is run instead of
    ``self.optimizer``. ``self.timestep`` is incremented once per step.

    Args:
        dataObj: object exposing ``getData(batchSize)``; the result is indexed
            as ``[0]`` (network input) and ``[1]`` (ground truth).
        save: if truthy, checkpoint ``self.sess`` to ``self.saveFile`` after
            the steps (tagged with ``global_step=self.timestep``).
        plot: if truthy, plot ``self.W_encode`` to
            ``self.plotDir + "weights_<timestep>.png"`` after the steps.
    """
    # Stays None when no steps run (displayPeriod <= 0), so the plot step
    # below does not hit an unbound name; Session.run accepts feed_dict=None.
    feedDict = None
    for i in range(self.displayPeriod):
        data = dataObj.getData(self.batchSize)
        feedDict = {self.input: data[0], self.gt: data[1]}

        if self.preTrain:
            self.sess.run(self.optimizerPre, feed_dict=feedDict)
        else:
            self.sess.run(self.optimizer, feed_dict=feedDict)

        # Throttled by the local step index i, but tagged with the global
        # self.timestep.
        if i % self.writeStep == 0:
            summary = self.sess.run(self.mergedSummary, feed_dict=feedDict)
            self.train_writer.add_summary(summary, self.timestep)
        if i % self.progress == 0:
            # A single-argument print() behaves the same under Python 2 and 3;
            # the double space reproduces `print "Timestep ", t` output.
            print("Timestep  %s" % (self.timestep,))
        self.timestep += 1

    if save:
        save_path = self.saver.save(self.sess, self.saveFile,
                                    global_step=self.timestep,
                                    write_meta_graph=False)
        print("Model saved in file: %s" % save_path)
    if plot:
        filename = self.plotDir + "weights_" + str(self.timestep) + ".png"
        # Uses the last batch fed above (None if no steps ran).
        np_w = self.sess.run(self.W_encode, feed_dict=feedDict)
        plot_weights(np_w, filename, order=[3, 0, 1, 2])
def trainW(self):
    """Run one ``self.optimizerW`` step on the current image, plotting periodically.

    When ``self.plotTimestep`` is a multiple of ``self.plotPeriod``, the
    dictionary ``self.weightImages`` (after a sign/exp display rescale) and
    the reconstruction ``self.recon`` of ``self.currImg`` are plotted before
    the update. Afterwards the next batch is loaded from ``self.dataObj`` and
    ``self.plotTimestep`` is advanced.
    """
    feedDict = {self.inputImage: self.currImg}

    if self.plotTimestep % self.plotPeriod == 0:
        np_V1_W = self.sess.run(self.weightImages)
        # Display-only transform: scale by sqrt(patch area), then map each
        # weight w to sign(w) * exp(|scaled w|).
        scale = np.sqrt(self.patchSizeX * self.patchSizeY)
        rescaled_V1_W = np.exp(np.abs(np_V1_W * scale)) * np.sign(np_V1_W)
        plot_weights(rescaled_V1_W, self.plotDir + "dict_" + str(self.timestep))

        np_recon = self.sess.run(self.recon, feed_dict=feedDict)
        plotRecon(np_recon, self.currImg,
                  self.plotDir + "recon_" + str(self.timestep), r=range(4))

    self.sess.run(self.optimizerW, feed_dict=feedDict)
    self.currImg = self.dataObj.getData(self.batchSize)
    self.plotTimestep += 1
def trainW(self):
    """Run one dictionary-learning step, with per-layer plots every plotPeriod calls.

    On plotting calls (``self.plotTimestep % self.plotPeriod == 0``), for each
    of the ``self.numLayers`` layers this writes the layer's weights
    (``visWeight``) and its two reconstructions (``visRecon`` and
    ``t_visRecon``) of ``self.currImg`` to PNG files under ``self.plotDir``.
    Every call then runs ``self.optimizerW``, loads the next batch from
    ``self.dataObj`` and advances ``self.plotTimestep``.
    """
    feeds = {self.inputImage: self.currImg}
    due = (self.plotTimestep % self.plotPeriod) == 0

    if due:
        for layer in range(self.numLayers):
            suffix = str(layer) + "_" + str(self.timestep) + ".png"

            weights = self.sess.run(self.visWeight[layer])
            plot_weights(weights, self.plotDir + "dict_S" + suffix)

            image = self.currImg
            recon = self.sess.run(self.visRecon[layer], feed_dict=feeds)
            t_recon = self.sess.run(self.t_visRecon[layer], feed_dict=feeds)
            plotRecon(recon, image, self.plotDir + "recon_S" + suffix, r=range(4))
            plotRecon(t_recon, image, self.plotDir + "t_recon_S" + suffix, r=range(4))

    self.sess.run(self.optimizerW, feed_dict=feeds)
    self.currImg = self.dataObj.getData(self.batchSize)
    self.plotTimestep += 1
def plot(self):
    """Plot the dictionary and a reconstruction every ``self.plotPeriod`` calls.

    Weight arrays with ``ndim == 2`` are sign/exp-rescaled and drawn with
    ``plot_1d_weights`` (passing ``self.V1_A`` as activity); otherwise
    ``plot_weights`` is used. Reconstructions that squeeze to 2-D are rescaled
    and drawn with ``plotRecon1d``; otherwise ``plotRecon`` is used.
    ``self.plotTimestep`` is advanced on every call.
    """
    if self.plotTimestep % self.plotPeriod == 0:
        scale = np.sqrt(self.patchSizeX * self.patchSizeY)

        np_V1_W = self.sess.run(self.weightImages)
        plotStr = self.plotDir + "dict_" + str(self.timestep)
        if np_V1_W.ndim == 2:
            # Activity is only used by the 1-D plot, so fetch it only here.
            np_V1_A = self.sess.run(self.V1_A)
            rescaled_V1_W = np.exp(np.abs(np_V1_W * scale)) * np.sign(np_V1_W)
            plot_1d_weights(rescaled_V1_W, plotStr, activity=np_V1_A)
        else:
            plot_weights(np_V1_W, plotStr)

        # self.recon is evaluated with the current image fed to
        # self.inputImage, matching how trainW evaluates it.
        np_inputImage = self.currImg
        feedDict = {self.inputImage: np_inputImage}
        np_recon = self.sess.run(self.recon, feed_dict=feedDict)
        reconStr = self.plotDir + "recon_" + str(self.timestep)
        if np.squeeze(np_recon).ndim == 2:
            rescaled_inputImage = np_inputImage * scale
            rescaled_recon = np_recon * scale
            exp_inputImage = np.squeeze(
                np.exp(np.abs(rescaled_inputImage) - 1e-10) * np.sign(np_inputImage))
            exp_recon = np.squeeze(
                np.exp(np.abs(rescaled_recon) - 1e-10) * np.sign(np_recon))
            plotRecon1d(exp_recon, exp_inputImage, reconStr, r=range(4))
        else:
            plotRecon(np_recon, np_inputImage, reconStr, r=range(4))
    self.plotTimestep += 1
def trainW(self):
    """Apply one ``self.optimizerW`` step to the current image, then load the next batch.

    Every ``self.plotPeriod`` calls (tracked by ``self.plotTimestep``), the
    dictionary ``self.weightImages`` and the reconstruction ``self.recon`` of
    the current image are plotted under ``self.plotDir`` before the update.
    """
    image = self.currImg
    feeds = {self.inputImage: image}

    should_plot = self.plotTimestep % self.plotPeriod == 0
    if should_plot:
        stamp = str(self.timestep)
        plot_weights(self.sess.run(self.weightImages),
                     self.plotDir + "dict_" + stamp + ".png")
        recon = self.sess.run(self.recon, feed_dict=feeds)
        plotRecon(recon, image, self.plotDir + "recon_" + stamp, r=range(4))

    self.sess.run(self.optimizerW, feed_dict=feeds)
    self.currImg = self.dataObj.getData(self.batchSize)
    self.plotTimestep += 1
def plotWeight(self):
    """Plot the current dictionary into a per-timestep subdirectory of plotDir.

    Creates ``<plotDir>/<timestep>/`` if it does not already exist. Weight
    arrays with ``ndim == 3`` are drawn with ``plot_1d_weights`` (with
    ``self.V1_A`` as activity and ``sepFeatures=True``); anything else is
    drawn with ``plot_weights``.
    """
    # Trailing "" keeps the directory separator at the end of the path.
    outPlotDir = os.path.join(self.plotDir, str(self.timestep), "")
    # Create without an exists-then-create race; Python 2's os.makedirs has
    # no exist_ok, so tolerate only the "already a directory" failure.
    try:
        os.makedirs(outPlotDir)
    except OSError:
        if not os.path.isdir(outPlotDir):
            raise

    np_V1_W = self.sess.run(self.weightImages)
    plotStr = outPlotDir + "dict_"
    if np_V1_W.ndim == 3:
        # Activity is only used by the 1-D plot, so fetch it only here.
        np_V1_A = self.sess.run(self.V1_A)
        plot_1d_weights(np_V1_W, plotStr, activity=np_V1_A, sepFeatures=True)
    else:
        plot_weights(np_V1_W, plotStr)
def trainW(self):
    """Update the weights once; every plotPeriod calls, also plot each layer first.

    Plotting writes, per layer in ``range(self.numLayers)``, the layer's
    weights (``visWeight``), its ``visRecon`` reconstruction and its
    ``t_visRecon`` reconstruction of ``self.currImg`` as PNG files under
    ``self.plotDir``. After that (on every call) ``self.optimizerW`` runs, the
    next batch is loaded from ``self.dataObj`` and ``self.plotTimestep`` is
    advanced.
    """
    def out_path(prefix, layer):
        # e.g. <plotDir>dict_S0_<timestep>.png
        return (self.plotDir + prefix + str(layer) + "_"
                + str(self.timestep) + ".png")

    feedDict = {self.inputImage: self.currImg}

    if not self.plotTimestep % self.plotPeriod:
        for layer in range(self.numLayers):
            plot_weights(self.sess.run(self.visWeight[layer]),
                         out_path("dict_S", layer))

            source = self.currImg
            recon = self.sess.run(self.visRecon[layer], feed_dict=feedDict)
            t_recon = self.sess.run(self.t_visRecon[layer], feed_dict=feedDict)
            plotRecon(recon, source, out_path("recon_S", layer), r=range(4))
            plotRecon(t_recon, source, out_path("t_recon_S", layer), r=range(4))

    self.sess.run(self.optimizerW, feed_dict=feedDict)
    self.currImg = self.dataObj.getData(self.batchSize)
    self.plotTimestep += 1
def trainModel(self, dataObj, save, plot):
    """Run ``self.displayPeriod`` optimizer steps, then optionally save and plot.

    Each step pulls a batch with ``dataObj.getData(self.batchSize)`` and feeds
    ``data[0]`` to ``self.input`` and ``data[1]`` to ``self.gt``. While
    ``self.preTrain`` is truthy, ``self.optimizerPre`` is run instead of
    ``self.optimizer``. ``self.timestep`` is incremented once per step.

    Args:
        dataObj: object exposing ``getData(batchSize)``; the result is indexed
            as ``[0]`` (network input) and ``[1]`` (ground truth).
        save: if truthy, checkpoint ``self.sess`` to ``self.saveFile`` after
            the steps (tagged with ``global_step=self.timestep``).
        plot: if truthy, plot ``self.W_encode`` to
            ``self.plotDir + "weights_<timestep>.png"`` after the steps.
    """
    # Stays None when no steps run (displayPeriod <= 0), so the plot step
    # below does not hit an unbound name; Session.run accepts feed_dict=None.
    feedDict = None
    for i in range(self.displayPeriod):
        data = dataObj.getData(self.batchSize)
        feedDict = {self.input: data[0], self.gt: data[1]}

        if self.preTrain:
            self.sess.run(self.optimizerPre, feed_dict=feedDict)
        else:
            self.sess.run(self.optimizer, feed_dict=feedDict)

        # Throttled by the local step index i, but tagged with the global
        # self.timestep.
        if i % self.writeStep == 0:
            summary = self.sess.run(self.mergedSummary, feed_dict=feedDict)
            self.train_writer.add_summary(summary, self.timestep)
        if i % self.progress == 0:
            # A single-argument print() behaves the same under Python 2 and 3;
            # the double space reproduces `print "Timestep ", t` output.
            print("Timestep  %s" % (self.timestep,))
        self.timestep += 1

    if save:
        save_path = self.saver.save(self.sess, self.saveFile,
                                    global_step=self.timestep,
                                    write_meta_graph=False)
        print("Model saved in file: %s" % save_path)
    if plot:
        filename = self.plotDir + "weights_" + str(self.timestep) + ".png"
        # Uses the last batch fed above (None if no steps ran).
        np_w = self.sess.run(self.W_encode, feed_dict=feedDict)
        plot_weights(np_w, filename, order=[3, 0, 1, 2])