Ejemplo n.º 1
0
 def trainModel(self, dataObj, save, plot):
     """Run ``self.displayPeriod`` training steps, then optionally checkpoint and plot.

     Each step fetches a batch with ``dataObj.getData(self.batchSize)``, feeds
     ``data[0]`` to ``self.input`` and ``data[1]`` to ``self.gt``, and runs
     ``self.optimizerPre`` when ``self.preTrain`` is truthy, otherwise
     ``self.optimizer``. Every ``self.writeStep`` steps the merged summary is
     written to ``self.train_writer``; every ``self.progress`` steps the current
     timestep is printed. ``self.timestep`` is incremented once per step.

     Args:
         dataObj: provider with ``getData(batchSize)`` returning an indexable
             pair (input, ground truth) -- pair meaning assumed from the feed
             keys; confirm against the data class.
         save: if truthy, save the session to ``self.saveFile`` with
             ``global_step=self.timestep`` and print the resulting path.
         plot: if truthy, evaluate ``self.W_encode`` and write
             ``<plotDir>weights_<timestep>.png`` via ``plot_weights``.
     """
     feedDict = None
     for i in range(self.displayPeriod):
         data = dataObj.getData(self.batchSize)
         feedDict = {self.input: data[0], self.gt: data[1]}
         if self.preTrain:
             self.sess.run(self.optimizerPre, feed_dict=feedDict)
         else:
             self.sess.run(self.optimizer, feed_dict=feedDict)
         if i % self.writeStep == 0:
             summary = self.sess.run(self.mergedSummary, feed_dict=feedDict)
             self.train_writer.add_summary(summary, self.timestep)
         if i % self.progress == 0:
             # Single-argument print() behaves the same on Python 2 and 3 (the
             # old ``print "Timestep ", x`` statement is a SyntaxError on 3)
             # and reproduces the original "Timestep  <n>" text.
             print("Timestep  %s" % (self.timestep,))
         self.timestep += 1
     if save:
         save_path = self.saver.save(self.sess,
                                     self.saveFile,
                                     global_step=self.timestep,
                                     write_meta_graph=False)
         print("Model saved in file: %s" % save_path)
     if plot:
         if feedDict is None:
             # displayPeriod <= 0 means no batch was fetched above (previously a
             # NameError here); fetch one so W_encode can still be fed.
             data = dataObj.getData(self.batchSize)
             feedDict = {self.input: data[0], self.gt: data[1]}
         filename = self.plotDir + "weights_" + str(self.timestep) + ".png"
         np_w = self.sess.run(self.W_encode, feed_dict=feedDict)
         plot_weights(np_w, filename, order=[3, 0, 1, 2])
Ejemplo n.º 2
0
    def trainW(self):
        """Run one weight-update step, periodically plotting dictionary and reconstruction.

        When ``self.plotTimestep`` is a multiple of ``self.plotPeriod`` (including
        the first call, when the counter is 0), ``self.weightImages`` is fetched,
        rescaled as ``exp(|W| * sqrt(patchSizeX * patchSizeY)) * sign(W)`` and
        written with ``plot_weights``; the reconstruction ``self.recon`` of the
        current batch is written with ``plotRecon`` next to the raw input.
        ``self.optimizerW`` is then run on ``self.currImg``, a new batch is fetched
        from ``self.dataObj`` and ``self.plotTimestep`` is incremented.
        """
        feedDict = {self.inputImage: self.currImg}

        if self.plotTimestep % self.plotPeriod == 0:
            # Only tensors that are actually plotted are evaluated: self.V1_A and
            # rescaled copies of the input/reconstruction had no live consumer,
            # yet cost an extra session run and two full-batch exp() passes.
            np_V1_W = self.sess.run(self.weightImages)
            scale = np.sqrt(self.patchSizeX * self.patchSizeY)
            # NOTE(review): presumably inverts a log-style transform applied to
            # the data upstream -- confirm against the data provider.
            rescaled_V1_W = np.exp(np.abs(np_V1_W * scale)) * np.sign(np_V1_W)
            plot_weights(rescaled_V1_W, self.plotDir + "dict_" + str(self.timestep))

            np_recon = self.sess.run(self.recon, feed_dict=feedDict)
            plotRecon(np_recon, self.currImg,
                      self.plotDir + "recon_" + str(self.timestep), r=range(4))

        # Update weights on the current batch, then advance to a new one.
        self.sess.run(self.optimizerW, feed_dict=feedDict)
        self.currImg = self.dataObj.getData(self.batchSize)
        self.plotTimestep += 1
Ejemplo n.º 3
0
    def trainW(self):
        """Apply one dictionary-learning step, dumping per-layer plots every plotPeriod calls.

        On plotting steps, for each layer ``l`` in ``range(self.numLayers)`` three
        PNGs are written under ``self.plotDir``: the visualised weights
        (``dict_S<l>_<timestep>.png``), the reconstruction from ``self.visRecon[l]``
        (``recon_S<l>_<timestep>.png``) and the one from ``self.t_visRecon[l]``
        (``t_recon_S<l>_<timestep>.png``). Afterwards ``self.optimizerW`` is run on
        the current batch and the next batch is pulled from ``self.dataObj``.
        """
        current = self.currImg
        feeds = {self.inputImage: current}

        due = self.plotTimestep % self.plotPeriod == 0
        if due:
            stamp = "_" + str(self.timestep) + ".png"
            for layer in range(self.numLayers):
                tag = str(layer) + stamp
                weights = self.sess.run(self.visWeight[layer])
                plot_weights(weights, self.plotDir + "dict_S" + tag)
                recon = self.sess.run(self.visRecon[layer], feed_dict=feeds)
                t_recon = self.sess.run(self.t_visRecon[layer], feed_dict=feeds)
                plotRecon(recon, current, self.plotDir + "recon_S" + tag, r=range(4))
                plotRecon(t_recon, current, self.plotDir + "t_recon_S" + tag, r=range(4))

        # Weight update on the batch that was just visualised, then move on.
        self.sess.run(self.optimizerW, feed_dict=feeds)
        self.currImg = self.dataObj.getData(self.batchSize)
        self.plotTimestep += 1
Ejemplo n.º 4
0
    def plot(self):
        """Periodically write dictionary and reconstruction plots; always advance the counter.

        When ``self.plotTimestep`` is a multiple of ``self.plotPeriod``:
          * ``self.weightImages`` is fetched. If it is 2-D it is rescaled as
            ``exp(|W| * sqrt(patchSizeX * patchSizeY)) * sign(W)`` and written
            with ``plot_1d_weights`` together with ``self.V1_A``; otherwise it is
            written unchanged with ``plot_weights``.
          * ``self.recon`` is evaluated on ``self.currImg``. If the squeezed
            result is 2-D, input and reconstruction are mapped through
            ``exp(|x * scale| - 1e-10) * sign(x)`` and drawn with ``plotRecon1d``;
            otherwise ``plotRecon`` draws them as-is.
        ``self.plotTimestep`` is incremented on every call.
        """
        if self.plotTimestep % self.plotPeriod == 0:
            scale = np.sqrt(self.patchSizeX * self.patchSizeY)

            np_V1_W = self.sess.run(self.weightImages)
            plotStr = self.plotDir + "dict_" + str(self.timestep)
            if np_V1_W.ndim == 2:
                # Activity is only consumed by the 1-D plot, so fetch it here only.
                np_V1_A = self.sess.run(self.V1_A)
                rescaled_V1_W = np.exp(np.abs(np_V1_W * scale)) * np.sign(np_V1_W)
                plot_1d_weights(rescaled_V1_W, plotStr, activity=np_V1_A)
            else:
                # Bug fix: this previously passed the undefined name ``V1_W``.
                plot_weights(np_V1_W, plotStr)

            # Bug fix: ``feedDict`` was never defined in this method (NameError).
            # NOTE(review): assumes the model's input placeholder is
            # ``self.inputImage`` and it is fed ``self.currImg`` -- confirm.
            np_inputImage = self.currImg
            feedDict = {self.inputImage: np_inputImage}
            np_recon = self.sess.run(self.recon, feed_dict=feedDict)

            reconStr = self.plotDir + "recon_" + str(self.timestep)
            if np.squeeze(np_recon).ndim == 2:
                rescaled_inputImage = np_inputImage * scale
                rescaled_recon = np_recon * scale
                exp_inputImage = np.squeeze(
                    np.exp(np.abs(rescaled_inputImage) - 1e-10) * np.sign(np_inputImage))
                exp_recon = np.squeeze(
                    np.exp(np.abs(rescaled_recon) - 1e-10) * np.sign(np_recon))
                plotRecon1d(exp_recon, exp_inputImage, reconStr, r=range(4))
            else:
                plotRecon(np_recon, np_inputImage, reconStr, r=range(4))

        self.plotTimestep += 1
Ejemplo n.º 5
0
    def trainW(self):
        """Advance the dictionary one optimisation step on the current batch.

        Every ``self.plotPeriod`` invocations (counted by ``self.plotTimestep``,
        so the very first call plots when the counter starts at 0) the current
        dictionary and a reconstruction of the batch are written under
        ``self.plotDir`` before the update is applied. The next batch is then
        fetched from ``self.dataObj`` and the counter is incremented.
        """
        batch = self.currImg
        feed = {self.inputImage: batch}

        if not self.plotTimestep % self.plotPeriod:
            step = str(self.timestep)
            plot_weights(self.sess.run(self.weightImages),
                         self.plotDir + "dict_" + step + ".png")
            plotRecon(self.sess.run(self.recon, feed_dict=feed), batch,
                      self.plotDir + "recon_" + step, r=range(4))

        self.sess.run(self.optimizerW, feed_dict=feed)
        self.currImg = self.dataObj.getData(self.batchSize)
        self.plotTimestep += 1
Ejemplo n.º 6
0
    def trainW(self):
        """One weight-learning iteration with periodic visual diagnostics.

        Side effects, in order:
          1. if ``self.plotTimestep % self.plotPeriod == 0``: plot
             ``self.weightImages`` to ``<plotDir>dict_<timestep>.png`` and the
             reconstruction ``self.recon`` of ``self.currImg`` to
             ``<plotDir>recon_<timestep>``;
          2. run ``self.optimizerW`` fed with ``self.currImg``;
          3. replace ``self.currImg`` with ``self.dataObj.getData(self.batchSize)``;
          4. increment ``self.plotTimestep``.
        """
        inputs = {self.inputImage: self.currImg}

        def visualise(tag):
            # Dictionary first, then the reconstruction of the batch being trained on.
            weights = self.sess.run(self.weightImages)
            plot_weights(weights, self.plotDir + "dict_" + tag + ".png")
            reconstruction = self.sess.run(self.recon, feed_dict=inputs)
            plotRecon(reconstruction, self.currImg, self.plotDir + "recon_" + tag, r=range(4))

        should_plot = (self.plotTimestep % self.plotPeriod) == 0
        if should_plot:
            visualise(str(self.timestep))

        self.sess.run(self.optimizerW, feed_dict=inputs)
        next_batch = self.dataObj.getData(self.batchSize)
        self.currImg = next_batch
        self.plotTimestep += 1
Ejemplo n.º 7
0
    def plotWeight(self):
        """Write the current dictionary plot into ``<plotDir>/<timestep>/``.

        Creates the per-timestep directory if needed, fetches
        ``self.weightImages`` and writes it with the path prefix
        ``<plotDir>/<timestep>/dict_``. A 3-D result goes to ``plot_1d_weights``
        (with ``self.V1_A`` as activity and ``sepFeatures=True``); anything else
        goes to ``plot_weights``.
        """
        outPlotDir = self.plotDir + "/" + str(self.timestep) + "/"
        # Create first and tolerate "already exists" afterwards, so a directory
        # appearing between a check and the create is not an error.
        try:
            os.makedirs(outPlotDir)
        except OSError:
            if not os.path.isdir(outPlotDir):
                raise

        np_V1_W = self.sess.run(self.weightImages)
        plotStr = outPlotDir + "dict_"
        if np_V1_W.ndim == 3:
            # Activity is only needed by the 1-D plot, so fetch it here only.
            np_V1_A = self.sess.run(self.V1_A)
            plot_1d_weights(np_V1_W,
                            plotStr,
                            activity=np_V1_A,
                            sepFeatures=True)
        else:
            # Bug fix: this previously passed the undefined name ``V1_W``.
            plot_weights(np_V1_W, plotStr)
Ejemplo n.º 8
0
    def trainW(self):
        """Perform one dictionary update, plotting every layer first on plotPeriod boundaries.

        For each layer index in ``range(self.numLayers)`` the plot step writes
        ``dict_S<l>_<timestep>.png`` (from ``self.visWeight[l]``),
        ``recon_S<l>_<timestep>.png`` (``self.visRecon[l]``) and
        ``t_recon_S<l>_<timestep>.png`` (``self.t_visRecon[l]``) under
        ``self.plotDir``; the reconstructions are drawn against ``self.currImg``.
        Afterwards ``self.optimizerW`` runs on the current batch, the next batch
        is fetched and ``self.plotTimestep`` is incremented.
        """
        feed_dict = {self.inputImage: self.currImg}

        def out_path(kind, layer):
            # <plotDir><kind>_S<layer>_<timestep>.png
            return (self.plotDir + kind + "_S" + str(layer) + "_"
                    + str(self.timestep) + ".png")

        if self.plotTimestep % self.plotPeriod == 0:
            for layer in range(self.numLayers):
                plot_weights(self.sess.run(self.visWeight[layer]), out_path("dict", layer))
                image = self.currImg
                fetched = [
                    self.sess.run(self.visRecon[layer], feed_dict=feed_dict),
                    self.sess.run(self.t_visRecon[layer], feed_dict=feed_dict),
                ]
                for kind, recon in zip(("recon", "t_recon"), fetched):
                    plotRecon(recon, image, out_path(kind, layer), r=range(4))

        self.sess.run(self.optimizerW, feed_dict=feed_dict)
        self.currImg = self.dataObj.getData(self.batchSize)
        self.plotTimestep += 1
Ejemplo n.º 9
0
 def trainModel(self, dataObj, save, plot):
     """Run ``self.displayPeriod`` training steps, then optionally checkpoint and plot.

     Each step fetches a batch with ``dataObj.getData(self.batchSize)``, feeds
     ``data[0]`` to ``self.input`` and ``data[1]`` to ``self.gt``, and runs
     ``self.optimizerPre`` when ``self.preTrain`` is truthy, otherwise
     ``self.optimizer``. Every ``self.writeStep`` steps the merged summary is
     written to ``self.train_writer``; every ``self.progress`` steps the current
     timestep is printed. ``self.timestep`` is incremented once per step.

     Args:
         dataObj: provider with ``getData(batchSize)`` returning an indexable
             pair (input, ground truth) -- pair meaning assumed from the feed
             keys; confirm against the data class.
         save: if truthy, save the session to ``self.saveFile`` with
             ``global_step=self.timestep`` and print the resulting path.
         plot: if truthy, evaluate ``self.W_encode`` and write
             ``<plotDir>weights_<timestep>.png`` via ``plot_weights``.
     """
     feedDict = None
     for i in range(self.displayPeriod):
         data = dataObj.getData(self.batchSize)
         feedDict = {self.input: data[0], self.gt: data[1]}
         if self.preTrain:
             self.sess.run(self.optimizerPre, feed_dict=feedDict)
         else:
             self.sess.run(self.optimizer, feed_dict=feedDict)
         if i % self.writeStep == 0:
             summary = self.sess.run(self.mergedSummary, feed_dict=feedDict)
             self.train_writer.add_summary(summary, self.timestep)
         if i % self.progress == 0:
             # Single-argument print() behaves the same on Python 2 and 3 (the
             # old ``print "Timestep ", x`` statement is a SyntaxError on 3)
             # and reproduces the original "Timestep  <n>" text.
             print("Timestep  %s" % (self.timestep,))
         self.timestep += 1
     if save:
         save_path = self.saver.save(self.sess,
                                     self.saveFile,
                                     global_step=self.timestep,
                                     write_meta_graph=False)
         print("Model saved in file: %s" % save_path)
     if plot:
         if feedDict is None:
             # displayPeriod <= 0 means no batch was fetched above (previously a
             # NameError here); fetch one so W_encode can still be fed.
             data = dataObj.getData(self.batchSize)
             feedDict = {self.input: data[0], self.gt: data[1]}
         filename = self.plotDir + "weights_" + str(self.timestep) + ".png"
         np_w = self.sess.run(self.W_encode, feed_dict=feedDict)
         plot_weights(np_w, filename, order=[3, 0, 1, 2])