예제 #1
0
    def trainW(self):
        """Take one weight-update step, periodically plotting every layer.

        When ``plotTimestep`` is a multiple of ``plotPeriod``, each of the
        ``numLayers`` layers gets three figures written under ``plotDir``:
        its visualised weights (``visWeight``), and the two reconstructions
        of the current batch (``visRecon`` and ``t_visRecon``).  Afterwards
        ``optimizerW`` is run on the current batch, the next batch is pulled
        from ``dataObj`` and ``plotTimestep`` is incremented.
        """
        feedDict = {self.inputImage: self.currImg}

        if self.plotTimestep % self.plotPeriod == 0:
            # Shared "_<timestep>.png" tail of every file name for this step.
            step_tail = "_" + str(self.timestep) + ".png"
            for layer in range(self.numLayers):
                layer_tail = str(layer) + step_tail

                layer_weights = self.sess.run(self.visWeight[layer])
                plot_weights(layer_weights,
                             self.plotDir + "dict_S" + layer_tail)

                # Evaluate both reconstructions before plotting either one.
                recon = self.sess.run(self.visRecon[layer],
                                      feed_dict=feedDict)
                t_recon = self.sess.run(self.t_visRecon[layer],
                                        feed_dict=feedDict)
                plotRecon(recon, self.currImg,
                          self.plotDir + "recon_S" + layer_tail,
                          r=range(4))
                plotRecon(t_recon, self.currImg,
                          self.plotDir + "t_recon_S" + layer_tail,
                          r=range(4))

        # Dictionary update on the batch that was just visualised.
        self.sess.run(self.optimizerW, feed_dict=feedDict)
        # Advance to a fresh batch for the next call.
        self.currImg = self.dataObj.getData(self.batchSize)
        self.plotTimestep += 1
예제 #2
0
    def trainW(self):
        """Take one weight-update step, periodically plotting progress.

        When ``plotTimestep`` is a multiple of ``plotPeriod``, two figures are
        written under ``plotDir``:

        * the dictionary (``weightImages``) after a sign-preserving
          exponential rescale;
        * the reconstruction of the current batch next to the raw input.

        Then ``optimizerW`` is run on the current batch, the next batch is
        fetched from ``dataObj`` and ``plotTimestep`` is incremented.
        """
        feedDict = {self.inputImage: self.currImg}

        # Visualization
        if self.plotTimestep % self.plotPeriod == 0:
            np_V1_W = self.sess.run(self.weightImages)

            # sqrt of the patch area, used to undo a per-patch normalisation.
            scale = np.sqrt(self.patchSizeX * self.patchSizeY)
            # NOTE(review): presumably inverts a signed-log encoding of the
            # weights. Exact zeros map to 0 while tiny nonzero values map to
            # roughly +/-1 -- confirm this is the intended transform.
            rescaled_V1_W = np.exp(np.abs(np_V1_W * scale)) * np.sign(np_V1_W)
            plot_weights(rescaled_V1_W,
                         self.plotDir + "dict_" + str(self.timestep))

            # Draw recons (raw, not rescaled).
            np_recon = self.sess.run(self.recon, feed_dict=feedDict)
            plotRecon(np_recon, self.currImg,
                      self.plotDir + "recon_" + str(self.timestep),
                      r=range(4))

        # Update weights
        self.sess.run(self.optimizerW, feed_dict=feedDict)
        # New image
        self.currImg = self.dataObj.getData(self.batchSize)
        self.plotTimestep += 1
예제 #3
0
    def plot(self):
        """Plot the dictionary and current reconstruction every ``plotPeriod`` calls.

        When ``plotTimestep`` is a multiple of ``plotPeriod``:

        * The dictionary (``weightImages``) is plotted.  A 2-D weight array is
          exponentially rescaled and drawn with ``plot_1d_weights`` together
          with the ``V1_A`` activity.  Otherwise the raw weights go to
          ``plot_weights``.
        * The reconstruction of the current batch is plotted.  A recon that
          squeezes to 2-D is rescaled and drawn with ``plotRecon1d``.
          Otherwise the raw recon and input go to ``plotRecon``.

        ``plotTimestep`` is incremented on every call.
        """
        # Visualization
        if self.plotTimestep % self.plotPeriod == 0:
            # Was missing: the original referenced an undefined `feedDict`.
            feedDict = {self.inputImage: self.currImg}
            np_V1_W = self.sess.run(self.weightImages)
            # sqrt of the patch area, used to undo a per-patch normalisation.
            scale = np.sqrt(self.patchSizeX * self.patchSizeY)

            plotStr = self.plotDir + "dict_" + str(self.timestep)
            if np_V1_W.ndim == 2:
                # Activity is only needed by the 1-D weight plot.
                np_V1_A = self.sess.run(self.V1_A)
                rescaled_V1_W = np.exp(np.abs(np_V1_W * scale)) * np.sign(np_V1_W)
                plot_1d_weights(rescaled_V1_W, plotStr, activity=np_V1_A)
            else:
                # Was `plot_weights(V1_W, ...)`: `V1_W` is undefined (NameError).
                plot_weights(np_V1_W, plotStr)

            np_inputImage = self.currImg
            np_recon = self.sess.run(self.recon, feed_dict=feedDict)

            # Draw recons
            if np.squeeze(np_recon).ndim == 2:
                rescaled_inputImage = np_inputImage * scale
                rescaled_recon = np_recon * scale

                # NOTE(review): the 1e-10 is subtracted inside exp(), which
                # only scales the result by ~(1 - 1e-10); confirm intent.
                exp_inputImage = np.squeeze(np.exp(np.abs(rescaled_inputImage) - 1e-10) * np.sign(np_inputImage))
                exp_recon = np.squeeze(np.exp(np.abs(rescaled_recon) - 1e-10) * np.sign(np_recon))

                plotRecon1d(exp_recon, exp_inputImage,
                            self.plotDir + "recon_" + str(self.timestep),
                            r=range(4))
            else:
                plotRecon(np_recon, np_inputImage,
                          self.plotDir + "recon_" + str(self.timestep),
                          r=range(4))

        self.plotTimestep += 1
예제 #4
0
    def plotRecon(self):
        """Plot reconstructions of the current batch into a per-timestep folder.

        Creates ``<plotDir>/<timestep>/`` if it does not exist and writes
        figures whose paths start with ``recon_`` there.

        * If the squeezed reconstruction is 3-D, it is drawn with
          ``plotRecon1d`` against the squeezed ``scaled_inputImage`` tensor.
        * Otherwise it is drawn with ``plotRecon`` against the raw
          ``currImg``.

        At most ``min(batchSize, 4)`` examples are requested in either case.
        """
        # Make directory for timestep
        outPlotDir = self.plotDir + "/" + str(self.timestep) + "/"
        if not os.path.exists(outPlotDir):
            os.makedirs(outPlotDir)

        np_inputImage = self.currImg
        feedDict = {self.inputImage: self.currImg}
        np_recon = np.squeeze(self.sess.run(self.recon, feed_dict=feedDict))

        # Never request more examples than the batch holds. The original
        # applied this cap only to the 1-D branch and hard-coded range(4)
        # for images.
        numRecon = min(self.batchSize, 4)

        # Draw recons
        plotStr = outPlotDir + "recon_"
        if np_recon.ndim == 3:
            rescaled_inputImage = np.squeeze(
                self.sess.run(self.scaled_inputImage, feed_dict=feedDict))
            plotRecon1d(np_recon,
                        rescaled_inputImage,
                        plotStr,
                        r=range(numRecon))
        else:
            # Resolves to the module-level `plotRecon` helper, not this
            # method: class attributes are not in scope inside a method body.
            plotRecon(np_recon, np_inputImage, plotStr, r=range(numRecon))
예제 #5
0
    def trainW(self):
        """Take one weight-update step on the current batch.

        On every ``plotPeriod``-th call (counted by ``plotTimestep``), two
        figures are first written under ``plotDir``: the dictionary
        (``weightImages``) and the reconstruction of the current batch.
        Then ``optimizerW`` runs on the current batch, the next batch is
        pulled from ``dataObj`` and ``plotTimestep`` is incremented.
        """
        batch = self.currImg
        feed = {self.inputImage: batch}

        if not self.plotTimestep % self.plotPeriod:
            stamp = str(self.timestep)
            plot_weights(self.sess.run(self.weightImages),
                         self.plotDir + "dict_" + stamp + ".png")
            reconstruction = self.sess.run(self.recon, feed_dict=feed)
            plotRecon(reconstruction, batch,
                      self.plotDir + "recon_" + stamp, r=range(4))

        # Gradient step on the dictionary, then move on to fresh data.
        self.sess.run(self.optimizerW, feed_dict=feed)
        self.currImg = self.dataObj.getData(self.batchSize)
        self.plotTimestep += 1
예제 #6
0
파일: ista.py 프로젝트: wen036/TFSparseCode
    def trainW(self):
        """Run one dictionary-learning step, with periodic visual snapshots.

        Snapshots (dictionary image and batch reconstruction) are saved under
        ``plotDir`` whenever ``plotTimestep`` is divisible by ``plotPeriod``.
        The weight optimizer then runs on the current batch, a new batch of
        ``batchSize`` samples is loaded, and ``plotTimestep`` advances by one.
        """
        feed = {self.inputImage: self.currImg}
        plot_now = self.plotTimestep % self.plotPeriod == 0

        if plot_now:
            stamp = str(self.timestep)
            dict_path = self.plotDir + "dict_" + stamp + ".png"
            recon_path = self.plotDir + "recon_" + stamp

            dictionary = self.sess.run(self.weightImages)
            plot_weights(dictionary, dict_path)

            reconstruction = self.sess.run(self.recon, feed_dict=feed)
            plotRecon(reconstruction, self.currImg, recon_path, r=range(4))

        # Update weights using the batch that was (possibly) just plotted.
        self.sess.run(self.optimizerW, feed_dict=feed)
        # Load the next batch for the following step.
        self.currImg = self.dataObj.getData(self.batchSize)
        self.plotTimestep += 1
예제 #7
0
    def trainW(self):
        """Advance training by one weight update, plotting all layers periodically.

        If ``plotTimestep`` is a multiple of ``plotPeriod``, then for every
        layer index below ``numLayers`` the following are saved as PNGs
        under ``plotDir``: the ``visWeight`` dictionary image and the
        ``visRecon`` / ``t_visRecon`` reconstructions of the current batch.
        The ``optimizerW`` op is then run on that batch, the next batch is
        fetched and ``plotTimestep`` is incremented.
        """
        feed = {self.inputImage: self.currImg}

        if self.plotTimestep % self.plotPeriod == 0:
            batch = self.currImg
            recon_sources = (self.visRecon, self.t_visRecon)
            recon_prefixes = ("recon_S", "t_recon_S")
            for idx in range(self.numLayers):
                tail = str(idx) + "_" + str(self.timestep) + ".png"
                plot_weights(self.sess.run(self.visWeight[idx]),
                             self.plotDir + "dict_S" + tail)
                # Evaluate both reconstructions first, then plot them in order.
                outputs = [self.sess.run(src[idx], feed_dict=feed)
                           for src in recon_sources]
                for prefix, output in zip(recon_prefixes, outputs):
                    plotRecon(output, batch, self.plotDir + prefix + tail,
                              r=range(4))

        self.sess.run(self.optimizerW, feed_dict=feed)
        self.currImg = self.dataObj.getData(self.batchSize)
        self.plotTimestep += 1