コード例 #1
0
    def trainClassifier(self):
        print '-------------------------------'
        outputPrefix = self.readField(self.config, self.name,
                                      "output_directory")
        outputDir = os.path.join(outputPrefix, self.name)
        if not os.path.exists(outputDir):
            os.mkdir(outputDir)
        maxEpoch = int(self.readField(self.config, self.name, "max_pt_epoch"))
        trainSize = int(
            self.readField(self.config, self.name, "classifier_train_size"))
        numBatch = int(trainSize / (self.batchsize))

        #         self.jsae.addAE(pretrain='mse')
        trainData = []
        valData = []
        testData = []
        trainDH = []
        output = None

        for i in xrange(self.modalsCnt):
            n = self.names[i]
            s = self.saes[i]

            t = self.readField(s.ae[1].config, s.ae[1].name, "train_data")
            trainData.append(gp.garray(np.load(t)))

            t = self.readField(s.ae[1].config, s.ae[1].name, "validation_data")
            valData.append(gp.garray(np.load(t)))

            t = self.readField(s.ae[1].config, s.ae[1].name, "train_data")
            trainDH.append(
                DataHandler(t, output, s.ae[1].vDim, s.ae[-1].hDim,
                            self.batchsize, numBatch))

        t = self.readField(self.config, self.name, "train_label")
        cat_cnt = int(self.readField(self.config, self.name, "cat_cnt"))
        labelDH = DataHandler(t, output, cat_cnt, cat_cnt, self.batchsize,
                              numBatch)

        evalFreq = int(self.readField(self.config, self.name, "eval_freq"))

        if evalFreq != 0:
            qsize = int(self.readField(self.config, self.name, "query_size"))
            labelPath = self.readField(self.config, self.name, "val_label")
            label = np.load(labelPath)
            print "path: ", labelPath
            trainLabelPath = self.readField(self.config, self.name,
                                            "train_label")
            trainLabel = np.load(trainLabelPath)
            queryPath = self.readField(self.config, self.name, "query")
            validation = evaluate.Evaluator(queryPath,
                                            label,
                                            os.path.join(outputDir, 'perf'),
                                            self.name,
                                            query_size=qsize,
                                            verbose=self.verbose)
            validation.setTrainLabel(trainLabel)

        testlabelPath = self.readField(self.config, self.name, "test_label")
        testlabel = np.load(testlabelPath)
        print "path: ", testlabelPath
        for i in xrange(self.modalsCnt):
            n = self.names[i]
            s = self.saes[i]

            t = self.readField(s.ae[1].config, s.ae[1].name, "test_data")
            testData.append(gp.garray(np.load(t)))
        test = evaluate.Evaluator(queryPath,
                                  testlabel,
                                  os.path.join(outputDir, 'perf'),
                                  self.name,
                                  query_size=qsize,
                                  verbose=self.verbose)
        test.setTrainLabel(trainLabel)

        print '>>>>>>>>>>>>>>>>>>>>>>pre-training the unfolded network<<<<<<<<<<<<<<<<<<<<'
        diff_cost = 0
        rec_cost = 0.1
        for epoch in range(maxEpoch):
            print 'depth is: ', self.jdepth - 1
            #             perf=np.zeros( nMetric)
            perf = 0
            for i in xrange(self.modalsCnt):
                trainDH[i].reset()
            labelDH.reset()

            print "epoch: ", epoch
            for i in range(numBatch):

                trainbatch = []

                for m in xrange(self.modalsCnt):
                    trainbatch.append(trainDH[m].getOneBatch())
                labelbatch = labelDH.getOneBatch()

                #                 for m in xrange(self.modalsCnt):
                #                     print trainbatch[m].shape
                #                 print labelbatch

                #use imgcost and txt cost
                curr, g, jg = self.trainClassifierOneBatch(trainbatch,
                                                           labelbatch,
                                                           epoch,
                                                           diff_cost=diff_cost,
                                                           recf=rec_cost)
                perf += curr

                for m in xrange(self.modalsCnt):
                    self.saes[m].updateParams(epoch,
                                              g[m],
                                              self.saes[m].ae,
                                              backprop=True)
                if self.has_joint:
                    self.jsae.updateParams(epoch,
                                           jg,
                                           self.jsae.ae,
                                           backprop=True)


#                 perf=self.aggregatePerf(perf, curr)

#             print 'perf is: ', perf
#             if evalFreq!=0 and (1+epoch) % evalFreq == 0:
#                 ele=self.getMMReps(valData)
#                 validation.evalClassification(ele, label, epoch, self.name, metric='euclidean')
#         print 'test:'
#         ele=self.getMMReps(testData)
#         test.evalClassification(ele, testlabel, epoch, self.name, metric='euclidean')
コード例 #2
0
ファイル: msae.py プロジェクト: yysherlock/msae
    def train(self):
        outputPrefix=self.readField(self.config,self.name,"output_directory")
        outputDir=os.path.join(outputPrefix,self.name)
        if not os.path.exists(outputDir):
            os.mkdir(outputDir)
        
        imageinput = self.readField(self.isae.ae[1].config, self.isae.ae[1].name, "train_data")
        textinput = self.readField(self.tsae.ae[1].config, self.tsae.ae[1].name, "train_data")

        if self.readField(self.config, self.name,"extract_reps")=="True":
            imageoutput=self.readField(self.isae.ae[-1].config, self.isae.ae[-1].name, "train_reps")
            textoutput=self.readField(self.tsae.ae[-1].config, self.tsae.ae[-1].name, "train_reps")
        else:
            imageoutput=None
            textoutput=None

        maxEpoch = int(self.readField(self.config, self.name, "max_epoch"))
        trainSize=int(self.readField(self.config, self.name, "train_size"))
        numBatch = int(trainSize / self.batchsize)
 
        normalizeImg=self.str2bool(self.readField(self.config, self.name, "normalize"))
        imgTrainDH=DataHandler(imageinput, imageoutput, self.isae.ae[1].vDim, self.isae.ae[-1].hDim, self.batchsize, numBatch,normalizeImg)
        txtTrainDH=DataHandler(textinput, textoutput, self.tsae.ae[1].vDim, self.tsae.ae[-1].hDim, self.batchsize, numBatch)

        showFreq = int(self.readField(self.config, self.name, "show_freq"))
        if showFreq > 0:
            visDir = os.path.join(outputDir, "vis")
            if not os.path.exists(visDir):
                os.makedirs(visDir)

        evalFreq = int(self.readField(self.config, self.name, "eval_freq"))
        if evalFreq!=0:
            qsize=int(self.readField(self.config, self.name, "query_size"))
            labelPath=self.readField(self.config,self.name,"label")
            label=np.load(labelPath)
            queryPath=self.readField(self.config, self.name, "query")
            validation=evaluate.Evaluator(queryPath,label,os.path.join(outputDir,'perf'), self.name, query_size=qsize,verbose=self.verbose)
            validateImagepath = self.readField(self.isae.ae[1].config, self.isae.ae[1].name, "validation_data")
            validateTextpath = self.readField(self.tsae.ae[1].config, self.tsae.ae[1].name, "validation_data")
            validateImgData = gp.garray(np.load(validateImagepath))
            if normalizeImg:
                validateImgData=imgTrainDH.doNormalization(validateImgData)
            validateTxtData = gp.garray(np.load(validateTextpath))
        else:
            print "Warning: no evluation setting!"

        nCommon, nMetric, title=self.getDisplayFields()
        if self.verbose:
            print title
 
        for epoch in range(maxEpoch):
            perf=np.zeros( nMetric)
            epoch1, imgcost, txtcost, diffcost=self.checkPath(epoch)
            imgTrainDH.reset()
            txtTrainDH.reset()
            for i in range(numBatch):
                img = imgTrainDH.getOneBatch() 
                txt = txtTrainDH.getOneBatch()
                curr= self.trainOneBatch(img, txt, epoch1, imgcost, txtcost, diffcost)
                perf=self.aggregatePerf(perf, curr)

            if evalFreq!=0 and (1+epoch) % evalFreq == 0:
                imgcode,txtcode=self.getReps(validateImgData, validateTxtData)
                validation.evalCrossModal(imgcode,txtcode,epoch,'V')

            if showFreq != 0 and (1+epoch) % showFreq == 0:
                imgcode,txtcode=self.getReps(validateImgData, validateTxtData)
                np.save(os.path.join(visDir,'%simg' % str((epoch+1)/showFreq)),imgcode)
                np.save(os.path.join(visDir,'%stxt' % str((epoch+1)/showFreq)),txtcode)

            if self.verbose:
                self.printEpochInfo(epoch, perf, nCommon)

        if self.readField(self.config, self.name, "checkpoint")=="True":
            self.doCheckpoint(outputDir)

        if self.readField(self.config, self.name,"extract_reps")=="True":
            if evalFreq!=0:
                self.extractValidationReps(validateImgData, validateTxtData, "validation_data","validation_reps")
            self.extractTrainReps(imgTrainDH, txtTrainDH, numBatch)

        self.saveConfig(outputDir)
コード例 #3
0
    def train(self):
        outputPrefix = self.readField(self.config, self.name,
                                      "output_directory")
        outputDir = os.path.join(outputPrefix, self.name)
        if not os.path.exists(outputDir):
            os.makedirs(outputDir)

        showFreq = int(self.readField(self.config, self.name, "show_freq"))
        if showFreq > 0:
            visDir = os.path.join(outputDir, 'vis')
            if not os.path.exists(visDir):
                os.mkdir(visDir)
        #do normalization for images if they are not normalized before
        normalize = self.str2bool(
            self.readField(self.config, self.name, "normalize"))
        trainDataSize = int(
            self.readField(self.config, self.name, "train_size"))
        numBatch = trainDataSize / self.batchsize
        trainDataPath = self.readField(self.config, self.name, "train_data")
        if self.readField(self.config, self.name, "extract_reps") == "True":
            trainRepsPath = self.readField(self.config, self.name,
                                           "train_reps")
        else:
            trainRepsPath = None
        trainDataLoader = DataHandler(trainDataPath, trainRepsPath, self.vDim,
                                      self.hDim, self.batchsize, numBatch,
                                      normalize)

        evalFreq = int(self.readField(self.config, self.name, 'eval_freq'))
        if evalFreq != 0:
            qsize = int(self.readField(self.config, self.name, "query_size"))
            evalPath = self.readField(self.config, self.name,
                                      "validation_data")
            labelPath = self.readField(self.config, self.name, "label")
            queryPath = self.readField(self.config, self.name, "query")
            label = np.load(labelPath)
            eval = Evaluator(queryPath,
                             label,
                             os.path.join(outputDir, 'perf'),
                             self.name,
                             query_size=qsize,
                             verbose=self.verbose)
            validation_data = gp.garray(np.load(evalPath))
            if normalize:
                validation_data = trainDataLoader.doNormalization(
                    validation_data)

        maxEpoch = int(self.readField(self.config, self.name, "max_epoch"))

        nCommon, nMetric, title = self.getDisplayFields()
        if self.verbose:
            print title
        for epoch in range(maxEpoch):
            perf = np.zeros(nMetric)
            trainDataLoader.reset()
            for i in range(numBatch):
                batch = trainDataLoader.getOneBatch()
                curr = self.trainOneBatch(batch, epoch, computeStat=True)
                perf = self.aggregatePerf(perf, curr)

            if showFreq != 0 and (1 + epoch) % showFreq == 0:
                validation_code = self.getReps(validation_data)
                np.save(os.path.join(visDir, '%dvis' % (1 + epoch)),
                        validation_code)
            if evalFreq != 0 and (1 + epoch) % evalFreq == 0:
                validation_code = self.getReps(validation_data)
                eval.evalSingleModal(validation_code, epoch, self.name + 'V')
                validation_code = None
            if self.verbose:
                self.printEpochInfo(epoch, perf, nCommon)

        if self.readField(self.config, self.name, "checkpoint") == "True":
            self.doCheckpoint(outputDir)

        if self.readField(self.config, self.name, "extract_reps") == "True":
            if evalFreq != 0:
                validation_reps_path = self.readField(self.config, self.name,
                                                      "validation_reps")
                self.extractValidationReps(validation_data,
                                           validation_reps_path)
            self.extractTrainReps(trainDataLoader, numBatch)

        self.saveConfig(outputDir)
コード例 #4
0
    def train(self):
        outputPrefix = self.readField(self.config, self.name,
                                      "output_directory")
        outputDir = os.path.join(outputPrefix, self.name)
        if not os.path.exists(outputDir):
            os.mkdir(outputDir)

        imageinput = self.readField(self.isae.ae[1].config,
                                    self.isae.ae[1].name, "train_data")
        textinput = self.readField(self.tsae.ae[1].config,
                                   self.tsae.ae[1].name, "train_data")

        if self.readField(self.config, self.name, "extract_reps") == "True":
            imageoutput = self.readField(self.isae.ae[-1].config,
                                         self.isae.ae[-1].name, "train_reps")
            textoutput = self.readField(self.tsae.ae[-1].config,
                                        self.tsae.ae[-1].name, "train_reps")
        else:
            imageoutput = None
            textoutput = None

        maxEpoch = int(self.readField(self.config, self.name, "max_epoch"))
        trainSize = int(self.readField(self.config, self.name, "train_size"))
        numBatch = int(trainSize / self.batchsize)

        normalizeImg = self.str2bool(
            self.readField(self.config, self.name, "normalize"))
        imgTrainDH = DataHandler(imageinput, imageoutput, self.isae.ae[1].vDim,
                                 self.isae.ae[-1].hDim, self.batchsize,
                                 numBatch, normalizeImg)
        txtTrainDH = DataHandler(textinput, textoutput, self.tsae.ae[1].vDim,
                                 self.tsae.ae[-1].hDim, self.batchsize,
                                 numBatch)

        showFreq = int(self.readField(self.config, self.name, "show_freq"))
        if showFreq > 0:
            visDir = os.path.join(outputDir, "vis")
            if not os.path.exists(visDir):
                os.makedirs(visDir)

        evalFreq = int(self.readField(self.config, self.name, "eval_freq"))
        if evalFreq != 0:
            qsize = int(self.readField(self.config, self.name, "query_size"))
            labelPath = self.readField(self.config, self.name, "label")
            label = np.load(labelPath)
            queryPath = self.readField(self.config, self.name, "query")
            validation = evaluate.Evaluator(queryPath,
                                            label,
                                            os.path.join(outputDir, 'perf'),
                                            self.name,
                                            query_size=qsize,
                                            verbose=self.verbose)
            validateImagepath = self.readField(self.isae.ae[1].config,
                                               self.isae.ae[1].name,
                                               "validation_data")
            validateTextpath = self.readField(self.tsae.ae[1].config,
                                              self.tsae.ae[1].name,
                                              "validation_data")
            validateImgData = gp.garray(np.load(validateImagepath))
            if normalizeImg:
                validateImgData = imgTrainDH.doNormalization(validateImgData)
            validateTxtData = gp.garray(np.load(validateTextpath))
        else:
            print "Warning: no evluation setting!"

        nCommon, nMetric, title = self.getDisplayFields()
        if self.verbose:
            print title

        for epoch in range(maxEpoch):
            perf = np.zeros(nMetric)
            epoch1, imgcost, txtcost, diffcost = self.checkPath(epoch)
            imgTrainDH.reset()
            txtTrainDH.reset()
            for i in range(numBatch):
                img = imgTrainDH.getOneBatch()
                txt = txtTrainDH.getOneBatch()
                curr = self.trainOneBatch(img, txt, epoch1, imgcost, txtcost,
                                          diffcost)
                perf = self.aggregatePerf(perf, curr)

            if evalFreq != 0 and (1 + epoch) % evalFreq == 0:
                imgcode, txtcode = self.getReps(validateImgData,
                                                validateTxtData)
                validation.evalCrossModal(imgcode, txtcode, epoch, 'V')

            if showFreq != 0 and (1 + epoch) % showFreq == 0:
                imgcode, txtcode = self.getReps(validateImgData,
                                                validateTxtData)
                np.save(
                    os.path.join(visDir, '%simg' % str(
                        (epoch + 1) / showFreq)), imgcode)
                np.save(
                    os.path.join(visDir, '%stxt' % str(
                        (epoch + 1) / showFreq)), txtcode)

            if self.verbose:
                self.printEpochInfo(epoch, perf, nCommon)

        if self.readField(self.config, self.name, "checkpoint") == "True":
            self.doCheckpoint(outputDir)

        if self.readField(self.config, self.name, "extract_reps") == "True":
            if evalFreq != 0:
                self.extractValidationReps(validateImgData, validateTxtData,
                                           "validation_data",
                                           "validation_reps")
            self.extractTrainReps(imgTrainDH, txtTrainDH, numBatch)

        self.saveConfig(outputDir)
コード例 #5
0
ファイル: model.py プロジェクト: yysherlock/msae
    def train(self):
        outputPrefix=self.readField(self.config,self.name,"output_directory")
        outputDir=os.path.join(outputPrefix,self.name)
        if not os.path.exists(outputDir):
            os.makedirs(outputDir)

        showFreq = int(self.readField(self.config, self.name, "show_freq"))
        if showFreq > 0:
            visDir = os.path.join(outputDir,'vis')
            if not os.path.exists(visDir):
                os.mkdir(visDir)
        #do normalization for images if they are not normalized before
        normalize=self.str2bool(self.readField(self.config, self.name, "normalize"))
        trainDataSize=int(self.readField(self.config, self.name, "train_size"))
        numBatch = trainDataSize / self.batchsize
        trainDataPath = self.readField(self.config, self.name, "train_data")
        if self.readField(self.config,self.name,"extract_reps")=="True":
            trainRepsPath=self.readField(self.config, self.name, "train_reps")
        else:
            trainRepsPath=None
        trainDataLoader=DataHandler(trainDataPath, trainRepsPath, self.vDim, self.hDim, self.batchsize,numBatch, normalize)

        evalFreq=int(self.readField(self.config,self.name,'eval_freq'))
        if evalFreq!=0:
            qsize=int(self.readField(self.config, self.name, "query_size"))
            evalPath=self.readField(self.config,self.name,"validation_data")
            labelPath=self.readField(self.config,self.name,"label")
            queryPath=self.readField(self.config, self.name, "query")
            label=np.load(labelPath)
            eval=Evaluator(queryPath,label ,os.path.join(outputDir,'perf'), self.name, query_size=qsize,verbose=self.verbose)
            validation_data=gp.garray(np.load(evalPath))
            if normalize:
                validation_data=trainDataLoader.doNormalization(validation_data)

        maxEpoch = int(self.readField(self.config, self.name, "max_epoch"))

        nCommon, nMetric, title=self.getDisplayFields()
        if self.verbose:
            print title
        for epoch in range(maxEpoch):
            perf=np.zeros( nMetric)
            trainDataLoader.reset()
            for i in range(numBatch):
                batch = trainDataLoader.getOneBatch()
                curr = self.trainOneBatch(batch, epoch, computeStat=True)
                perf=self.aggregatePerf(perf, curr)

            if showFreq != 0 and (1+epoch) % showFreq == 0:
                validation_code=self.getReps(validation_data)
                np.save(os.path.join(visDir, '%dvis' % (1+epoch)), validation_code)
            if evalFreq !=0 and (1+epoch) % evalFreq ==0:
                validation_code=self.getReps(validation_data)
                eval.evalSingleModal(validation_code,epoch,self.name+'V')
                validation_code=None
            if self.verbose:
                self.printEpochInfo(epoch,perf,nCommon)

        if self.readField(self.config,self.name,"checkpoint")=="True":
            self.doCheckpoint(outputDir)

        if self.readField(self.config,self.name,"extract_reps")=="True":
            if evalFreq!=0:
                validation_reps_path=self.readField(self.config, self.name, "validation_reps")
                self.extractValidationReps(validation_data, validation_reps_path)
            self.extractTrainReps(trainDataLoader, numBatch)

        self.saveConfig(outputDir)