Example #1
    def fit(self, hmodel, Data, LP=None):
        ''' Run VB learning algorithm to fit global parameters of hmodel to Data.

        Returns
        --------
        Info : dict of run information, with fields
        * evBound : final ELBO evidence bound
        * status : str message indicating reason for termination
                   {'converged', 'max laps exceeded'}
        * LP : dict of local parameters for final model
        '''
        prevBound = -np.inf
        isConverged = False

        # Save initial state
        self.saveParams(0, hmodel)

        # Custom func hook
        self.eval_custom_func(
            isInitial=1, **makeDictOfAllWorkspaceVars(**vars()))

        self.set_start_time_now()

        # TODO: remove this hard-coded flag; it exists only for debugging
        isParallel = True
        self.nDoc = Data.nDoc
        if isParallel:
            # Create a JobQ (to hold tasks to be done)
            # and a ResultsQ (to hold results of completed tasks)
            manager = multiprocessing.Manager()
            self.JobQ = manager.Queue()
            self.ResultQ = manager.Queue()

            # Get the function handles
            makeDataSliceFromSharedMem = Data.getDataSliceFunctionHandle()
            o_calcLocalParams, o_calcSummaryStats = hmodel.obsModel.\
                getLocalAndSummaryFunctionHandles()
            a_calcLocalParams, a_calcSummaryStats = hmodel.allocModel.\
                getLocalAndSummaryFunctionHandles()

            # Create the shared memory
            dataSharedMem = Data.getRawDataAsSharedMemDict()
            aSharedMem = hmodel.allocModel.fillSharedMemDictForLocalStep()
            oSharedMem = hmodel.obsModel.fillSharedMemDictForLocalStep()

            # Create multiple workers
            for uid in range(self.nWorkers):
                SharedMemWorker(uid, self.JobQ, self.ResultQ,
                                makeDataSliceFromSharedMem,
                                o_calcLocalParams,
                                o_calcSummaryStats,
                                a_calcLocalParams,
                                a_calcSummaryStats,
                                dataSharedMem,
                                aSharedMem,
                                oSharedMem,
                                LPkwargs=self.algParamsLP,
                                verbose=1).start()
        else:
            # Serial mode: store shared memory and data handles on self
            aSharedMem = hmodel.allocModel.fillSharedMemDictForLocalStep()
            oSharedMem = hmodel.obsModel.fillSharedMemDictForLocalStep()
            self.dataSharedMem = Data.getRawDataAsSharedMemDict()
            self.makeDataSliceFromSharedMem = Data.getDataSliceFunctionHandle()

        for iterid in range(1, self.algParams['nLap'] + 1):
            lap = self.algParams['startLap'] + iterid
            nLapsCompleted = lap - self.algParams['startLap']
            self.set_random_seed_at_lap(lap)

            if isParallel:
                # TODO: fill in additional params
                SS = self.calcLocalParamsAndSummarize(hmodel)
            else:
                SS = self.serialCalcLocalParamsAndSummarize(hmodel)

            # Global/M step
            hmodel.update_global_params(SS)

            # update the memory
            aSharedMem = hmodel.allocModel.fillSharedMemDictForLocalStep(
                aSharedMem)
            oSharedMem = hmodel.obsModel.fillSharedMemDictForLocalStep(
                oSharedMem)

            # ELBO calculation
            evBound = hmodel.calc_evidence(Data=Data, SS=SS)

            if lap > 1.0:
                # Report warning if bound isn't increasing monotonically
                self.verify_evidence(evBound, prevBound)

            # Check convergence of expected counts
            countVec = SS.getCountVec()
            if lap > 1.0:
                isConverged = self.isCountVecConverged(countVec, prevCountVec)
                self.setStatus(lap, isConverged)

            # Display progress
            self.updateNumDataProcessed(Data.get_size())
            if self.isLogCheckpoint(lap, iterid):
                self.printStateToLog(hmodel, evBound, lap, iterid)

            # Save diagnostics and params
            if self.isSaveDiagnosticsCheckpoint(lap, iterid):
                self.saveDiagnostics(lap, SS, evBound)
            if self.isSaveParamsCheckpoint(lap, iterid):
                self.saveParams(lap, hmodel, SS)

            # Custom func hook
            self.eval_custom_func(**makeDictOfAllWorkspaceVars(**vars()))

            if nLapsCompleted >= self.algParams['minLaps'] and isConverged:
                break
            prevBound = evBound
            prevCountVec = countVec.copy()
            # .... end loop over laps

        # Finished! Save, print and exit
        if isParallel:
            for workerID in range(self.nWorkers):
                # Passing None to JobQ is the shutdown signal
                self.JobQ.put(None)
        self.saveParams(lap, hmodel, SS)
        self.printStateToLog(hmodel, evBound, lap, iterid, isFinal=1)
        self.eval_custom_func(
            isFinal=1, **makeDictOfAllWorkspaceVars(**vars()))

        return self.buildRunInfo(evBound=evBound, SS=SS)
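
The parallel branch above follows a standard producer/consumer pattern: tasks
go on JobQ, workers post partial results to ResultQ, and putting None on JobQ
shuts a worker down. A minimal, self-contained sketch of that pattern, with a
hypothetical worker body standing in for bnpy's SharedMemWorker:

    import multiprocessing

    def worker(jobQ, resultQ):
        while True:
            task = jobQ.get()
            if task is None:  # None is the shutdown signal
                break
            resultQ.put(task * 2)  # stand-in for the local + summary step

    if __name__ == '__main__':
        manager = multiprocessing.Manager()
        jobQ, resultQ = manager.Queue(), manager.Queue()
        procs = [multiprocessing.Process(target=worker, args=(jobQ, resultQ))
                 for _ in range(4)]
        for p in procs:
            p.start()
        for task in range(8):
            jobQ.put(task)
        results = [resultQ.get() for _ in range(8)]
        for p in procs:
            jobQ.put(None)  # one shutdown signal per worker
        for p in procs:
            p.join()
        print(sorted(results))
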
Example #2
    def fit(self, hmodel, DataIterator):
        ''' Run learning algorithm that fits parameters of hmodel to Data.

        Returns
        --------
        Info : dict of run information.

        Post Condition
        --------
        hmodel updated in place with improved global parameters.
        '''
        self.set_start_time_now()
        # Initialize Progress Tracking vars like nBatch, lapFrac, etc.
        iterid, lapFrac = self.initProgressTrackVars(DataIterator)

        # Keep list of params that should be retained across laps
        mkeys = hmodel.allocModel.get_keys_for_memoized_local_params()
        self.memoLPkeys = mkeys

        # Save initial state
        self.saveParams(lapFrac, hmodel)

        # Custom func hook
        self.eval_custom_func(
            isInitial=1, **makeDictOfAllWorkspaceVars(**vars()))

        # Begin loop over batches of data...
        SS = None
        isConverged = False
        while DataIterator.has_next_batch():

            # Grab new data
            Dchunk = DataIterator.get_next_batch()
            batchID = DataIterator.batchID
            Dchunk.batchID = batchID

            # Update progress-tracking variables
            iterid += 1
            lapFrac = (iterid + 1) * self.lapFracInc
            self.lapFrac = lapFrac
            nLapsCompleted = lapFrac - self.algParams['startLap']
            self.set_random_seed_at_lap(lapFrac)
            if self.doDebugVerbose():
                self.print_msg('========================== lap %.2f batch %d'
                               % (lapFrac, batchID))

            # Local/E step
            self.algParamsLP['lapFrac'] = lapFrac  # logging
            self.algParamsLP['batchID'] = batchID
            LPchunk = self.memoizedLocalStep(hmodel, Dchunk, batchID)
            self.saveDebugStateAtBatch('Estep', batchID, Dchunk=Dchunk,
                                       SS=SS, hmodel=hmodel, LPchunk=LPchunk)

            # Summary step
            SS, SSchunk = self.memoizedSummaryStep(hmodel, SS,
                                                   Dchunk, LPchunk, batchID)
            # Global step
            self.GlobalStep(hmodel, SS, lapFrac)

            # ELBO calculation
            loss = -1 * hmodel.calc_evidence(SS=SS)
            if nLapsCompleted > 1.0:
                # loss decreases monotonically AFTER first lap
                self.verify_monotonic_decrease(loss, prev_loss, lapFrac)

            if self.doDebug() and lapFrac >= 1.0:
                self.verifyELBOTracking(hmodel, SS, loss)

            self.saveDebugStateAtBatch(
                'Mstep', batchID, Dchunk=Dchunk, SSchunk=SSchunk,
                SS=SS, hmodel=hmodel, LPchunk=LPchunk)

            # Assess convergence
            countVec = SS.getCountVec()
            if lapFrac > 1.0:
                isConverged = self.isCountVecConverged(countVec, prevCountVec)
                self.setStatus(lapFrac, isConverged)

            # Display progress
            self.updateNumDataProcessed(Dchunk.get_size())
            if self.isLogCheckpoint(lapFrac, iterid):
                self.printStateToLog(hmodel, loss, lapFrac, iterid)

            # Save diagnostics and params
            if self.isSaveDiagnosticsCheckpoint(lapFrac, iterid):
                self.saveDiagnostics(lapFrac, SS, loss)
            if self.isSaveParamsCheckpoint(lapFrac, iterid):
                self.saveParams(lapFrac, hmodel, SS)

            # Custom func hook
            self.eval_custom_func(**makeDictOfAllWorkspaceVars(**vars()))

            if isConverged and self.isLastBatch(lapFrac) \
               and nLapsCompleted >= self.algParams['minLaps']:
                break
            prevCountVec = countVec.copy()
            prev_loss = loss
            # .... end loop over data

        # Finished! Save, print and exit
        self.printStateToLog(hmodel, loss, lapFrac, iterid, isFinal=1)
        self.saveParams(lapFrac, hmodel, SS)
        self.eval_custom_func(
            isFinal=1, **makeDictOfAllWorkspaceVars(**vars()))

        if hasattr(DataIterator, 'Data'):
            Data = DataIterator.Data
        else:
            Data = DataIterator.getBatch(0)
        return self.buildRunInfo(Data=Data, loss=loss, SS=SS,
                                 LPmemory=self.LPmemory,
                                 SSmemory=self.SSmemory)
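
memoizedSummaryStep keeps whole-dataset sufficient statistics current by
subtracting each batch's stale contribution before adding its fresh one, so
SS always equals the sum over distinct batches. A minimal numpy sketch of
that incremental update (the names here are illustrative, not the bnpy API):

    import numpy as np

    def incremental_update(SStotal, SSperBatch, batchID, SSchunk):
        # Replace this batch's old stats with its new ones
        if SStotal is None:
            SStotal = SSchunk.copy()
        else:
            if batchID in SSperBatch:
                SStotal -= SSperBatch[batchID]
            SStotal += SSchunk
        SSperBatch[batchID] = SSchunk.copy()
        return SStotal

    SSperBatch = dict()
    SStotal = None
    for lap in range(3):
        for batchID, stats in enumerate([np.ones(3), 2 * np.ones(3)]):
            SStotal = incremental_update(SStotal, SSperBatch, batchID, stats)
    print(SStotal)  # [3. 3. 3.] on every lap: the sum over distinct batches
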
Example #3
    def fit(self, hmodel, DataIterator, SS=None):
        ''' Run stochastic variational inference to fit hmodel parameters to Data.

        Returns
        --------
        Info : dict of run information.

        Post Condition
        --------
        hmodel updated in place with improved global parameters.
        '''
        self.set_start_time_now()
        LP = None
        rho = 1.0  # Learning rate
        nBatch = float(DataIterator.nBatch)

        # Set-up progress-tracking variables
        iterid = -1
        lapFrac = np.maximum(0, self.algParams['startLap'] - 1.0 / nBatch)
        if lapFrac > 0:
            # When restarting an existing run,
            #  need to start with last update for final batch from previous lap
            DataIterator.lapID = int(np.ceil(lapFrac)) - 1
            DataIterator.curLapPos = nBatch - 2
            iterid = int(nBatch * lapFrac) - 1

        # Save initial state
        self.saveParams(lapFrac, hmodel)

        # Custom func hook
        self.eval_custom_func(isInitial=1,
                              **makeDictOfAllWorkspaceVars(**vars()))
        ElapsedTimeLogger.writeToLogOnLapCompleted(lapFrac)

        if self.algParams['doMemoELBO']:
            SStotal = None
            SSPerBatch = dict()
        else:
            loss_running_sum = 0
            loss_per_batch = np.zeros(int(nBatch))
        while DataIterator.has_next_batch():

            # Grab new data
            Dchunk = DataIterator.get_next_batch()
            batchID = DataIterator.batchID
            Dchunk.batchID = batchID

            # Update progress-tracking variables
            iterid += 1
            lapFrac += 1.0 / nBatch
            self.lapFrac = lapFrac
            nLapsCompleted = lapFrac - self.algParams['startLap']
            self.set_random_seed_at_lap(lapFrac)

            # E step
            self.algParamsLP['batchID'] = batchID
            self.algParamsLP['lapFrac'] = lapFrac  # logging
            if batchID in self.LPmemory:
                batchLP = self.load_batch_local_params_from_memory(batchID)
            else:
                batchLP = None
            LP = hmodel.calc_local_params(Dchunk,
                                          batchLP,
                                          doLogElapsedTime=True,
                                          **self.algParamsLP)
            rho = (1 + iterid + self.rhodelay)**(-1.0 * self.rhoexp)
            if self.algParams['doMemoELBO']:
                # SS step. Scale at size of current batch.
                SS = hmodel.get_global_suff_stats(Dchunk,
                                                  LP,
                                                  doLogElapsedTime=True,
                                                  doPrecompEntropy=True)
                if self.algParams['doMemoizeLocalParams']:
                    self.save_batch_local_params_to_memory(batchID, LP)
                # Incremental updates for whole-dataset stats.
                # Must happen before amplification.
                if batchID in SSPerBatch:
                    SStotal -= SSPerBatch[batchID]
                if SStotal is None:
                    SStotal = SS.copy()
                else:
                    SStotal += SS
                SSPerBatch[batchID] = SS.copy()

                # Scale up to size of whole dataset.
                if hasattr(Dchunk, 'nDoc'):
                    ampF = Dchunk.nDocTotal / float(Dchunk.nDoc)
                    SS.applyAmpFactor(ampF)
                else:
                    ampF = Dchunk.nObsTotal / float(Dchunk.nObs)
                    SS.applyAmpFactor(ampF)
                # M step with learning rate
                hmodel.update_global_params(SS, rho, doLogElapsedTime=True)
                # ELBO step
                assert not SStotal.hasAmpFactor()
                loss = -1 * hmodel.calc_evidence(
                    SS=SStotal,
                    doLogElapsedTime=True,
                    afterGlobalStep=not self.algParams['useSlackTermsInELBO'])
            else:
                # SS step. Scale at size of current batch.
                SS = hmodel.get_global_suff_stats(Dchunk,
                                                  LP,
                                                  doLogElapsedTime=True)

                # Scale up to size of whole dataset.
                if hasattr(Dchunk, 'nDoc'):
                    ampF = Dchunk.nDocTotal / float(Dchunk.nDoc)
                    SS.applyAmpFactor(ampF)
                else:
                    ampF = Dchunk.nObsTotal / float(Dchunk.nObs)
                    SS.applyAmpFactor(ampF)

                # M step with learning rate
                hmodel.update_global_params(SS, rho, doLogElapsedTime=True)

                # ELBO step
                assert SS.hasAmpFactor()
                cur_batch_loss = -1 * hmodel.calc_evidence(
                    Dchunk, SS, LP, doLogElapsedTime=True)
                if loss_per_batch[batchID] != 0:
                    loss_running_sum -= loss_per_batch[batchID]
                loss_running_sum += cur_batch_loss
                loss_per_batch[batchID] = cur_batch_loss
                loss = loss_running_sum / nBatch

            # Display progress
            self.updateNumDataProcessed(Dchunk.get_size())
            if self.isLogCheckpoint(lapFrac, iterid):
                self.printStateToLog(hmodel, loss, lapFrac, iterid, rho=rho)

            # Save diagnostics and params
            if self.isSaveDiagnosticsCheckpoint(lapFrac, iterid):
                self.saveDiagnostics(lapFrac, SS, loss)
            if self.isSaveParamsCheckpoint(lapFrac, iterid):
                self.saveParams(lapFrac, hmodel, tryToSparsifyOutput=1)
                # don't save SS here, since it's for one batch only
            self.eval_custom_func(**makeDictOfAllWorkspaceVars(**vars()))

            if self.isLastBatch(lapFrac):
                ElapsedTimeLogger.writeToLogOnLapCompleted(lapFrac)
            # .... end loop over data

        # Finished! Save, print and exit
        self.printStateToLog(hmodel, loss, lapFrac, iterid, isFinal=1)
        self.saveParams(lapFrac, hmodel, SS)
        self.eval_custom_func(isFinal=1,
                              **makeDictOfAllWorkspaceVars(**vars()))

        return self.buildRunInfo(Data=DataIterator, loss=loss, SS=SS)
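
Two quantities drive the stochastic updates above: a decaying learning rate
rho = (1 + t + rhodelay) ** (-rhoexp), and an amplification factor that
scales one batch's statistics up to the size of the full dataset. A short
sketch of both (the rhodelay and rhoexp values are illustrative only):

    import numpy as np

    rhodelay, rhoexp = 1.0, 0.55  # illustrative hyperparameters
    t = np.arange(10)
    rho = (1.0 + t + rhodelay) ** (-rhoexp)
    # Robbins-Monro style decay: rho shrinks slowly enough that learning
    # continues (sum of rho diverges) yet fast enough to converge
    # (sum of rho**2 is finite); any 0.5 < rhoexp <= 1 satisfies both.

    nDoc, nDocTotal = 100, 10000
    ampF = nDocTotal / float(nDoc)  # batch stats * ampF ~ full-dataset scale
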
Example #4
    def fit(self, hmodel, Data, LP=None):
        ''' Run VB learning to fit global parameters of hmodel to Data

        Returns
        --------
        Info : dict of run information.

        Post Condition
        --------
        hmodel updated in place with improved global parameters.
        '''
        self.set_start_time_now()
        prev_loss = np.inf
        isConverged = False
        # Save initial state
        self.saveParams(0, hmodel)
        # Custom func hook
        self.eval_custom_func(
            isInitial=1, **makeDictOfAllWorkspaceVars(**vars()))
        for iterid in range(1, self.algParams['nLap'] + 1):
            lap = self.algParams['startLap'] + iterid
            nLapsCompleted = lap - self.algParams['startLap']
            self.set_random_seed_at_lap(lap)

            # Local/E step
            self.algParamsLP['lapFrac'] = lap  # logging
            self.algParamsLP['batchID'] = 1
            LP = hmodel.calc_local_params(
                Data, LP, doLogElapsedTime=True, **self.algParamsLP)

            # Summary step
            SS = hmodel.get_global_suff_stats(Data, LP, doLogElapsedTime=True)

            # Global/M step
            hmodel.update_global_params(SS, doLogElapsedTime=True)

            # ELBO calculation
            cur_loss = -1 * hmodel.calc_evidence(
                Data, SS, LP, doLogElapsedTime=True)
            if lap > 1.0:
                # Report warning if loss function isn't behaving monotonically
                self.verify_monotonic_decrease(cur_loss, prev_loss)

            # Check convergence of expected counts
            countVec = SS.getCountVec()
            if lap > 1.0:
                isConverged = self.isCountVecConverged(countVec, prevCountVec)
                self.setStatus(lap, isConverged)

            # Display progress
            self.updateNumDataProcessed(Data.get_size())
            if self.isLogCheckpoint(lap, iterid):
                self.printStateToLog(hmodel, cur_loss, lap, iterid)

            # Save diagnostics and params
            if self.isSaveDiagnosticsCheckpoint(lap, iterid):
                self.saveDiagnostics(lap, SS, cur_loss)
            if self.isSaveParamsCheckpoint(lap, iterid):
                self.saveParams(lap, hmodel, SS)

            # Custom func hook
            self.eval_custom_func(**makeDictOfAllWorkspaceVars(**vars()))

            # Write elapsed times to disk
            ElapsedTimeLogger.writeToLogOnLapCompleted(lap)

            if nLapsCompleted >= self.algParams['minLaps'] and isConverged:
                break
            prev_loss = cur_loss
            prevCountVec = countVec.copy()
            # .... end loop over laps

        # Finished! Save, print and exit
        self.saveParams(lap, hmodel, SS)
        self.printStateToLog(hmodel, cur_loss, lap, iterid, isFinal=1)
        self.eval_custom_func(
            isFinal=1, **makeDictOfAllWorkspaceVars(**vars()))

        return self.buildRunInfo(Data=Data, loss=cur_loss, SS=SS, LP=LP)
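
isCountVecConverged above declares convergence when the expected counts per
component stop changing between laps. A plausible sketch of such a test (the
tolerance and normalization here are assumptions, not bnpy's exact rule):

    import numpy as np

    def is_count_vec_converged(countVec, prevCountVec, tol=1e-4):
        # Largest per-component change, relative to the total count mass
        maxDiff = np.max(np.abs(countVec - prevCountVec))
        return maxDiff / float(countVec.sum()) < tol

    print(is_count_vec_converged(np.array([50.0, 30.0, 20.0]),
                                 np.array([50.001, 29.999, 20.0])))  # True
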
Example #5
    def fit(self, hmodel, DataIterator, LP=None, **kwargs):
        ''' Run learning algorithm that fits parameters of hmodel to Data.

        Returns
        --------
        Info : dict of run information.

        Post Condition
        --------
        hmodel updated in place with improved global parameters.
        '''
        origmodel = hmodel

        self.ActiveIDVec = np.arange(hmodel.obsModel.K)
        self.maxUID = self.ActiveIDVec.max()

        # Initialize Progress Tracking vars like nBatch, lapFrac, etc.
        iterid, lapFrac = self.initProgressTrackVars(DataIterator)

        # Save initial state
        self.saveParams(lapFrac, hmodel)

        # Custom func hook
        self.eval_custom_func(
            isInitial=1, **makeDictOfAllWorkspaceVars(**vars()))

        # Setup workers for parallel runs
        if self.nWorkers > 0:
            JobQ, ResultQ, aSharedMem, oSharedMem = setupQueuesAndWorkers(
                DataIterator, hmodel,
                nWorkers=self.nWorkers,
                algParamsLP=self.algParamsLP)
            self.JobQ = JobQ
            self.ResultQ = ResultQ

        # Prep for birth
        BirthPlans = list()
        BirthResults = list()
        prevBirthResults = list()

        # Prep for merge
        MergePrepInfo = dict()
        if self.hasMove('merge'):
            mergeStartLap = self.algParams['merge']['mergeStartLap']
        else:
            mergeStartLap = 0
        order = None

        # Prep for delete
        DeletePlans = list()

        # Begin loop over batches of data...
        SS = None
        isConverged = False
        self.set_start_time_now()
        while DataIterator.has_next_batch():

            batchID = DataIterator.get_next_batch(batchIDOnly=1)

            # Update progress-tracking variables
            iterid += 1
            lapFrac = (iterid + 1) * self.lapFracInc
            self.lapFrac = lapFrac
            self.set_random_seed_at_lap(lapFrac)
            if self.doDebugVerbose():
                self.print_msg('========================== lap %.2f batch %d'
                               % (lapFrac, batchID))

            # Prepare for merges
            if self.hasMove('merge') and self.doMergePrepAtLap(lapFrac):
                MergePrepInfo = self.preparePlansForMerge(
                    hmodel, SS, MergePrepInfo,
                    order=order,
                    BirthResults=BirthResults,
                    lapFrac=lapFrac)
            elif self.isFirstBatch(lapFrac):
                if self.doMergePrepAtLap(lapFrac + 1):
                    MergePrepInfo = dict(
                        mergePairSelection=self.algParams[
                            'merge']['mergePairSelection'])
                else:
                    MergePrepInfo = dict()

            # Reset selection terms to zero
            if self.isFirstBatch(lapFrac):
                if SS is not None and SS.hasSelectionTerms():
                    SS._SelectTerms.setAllFieldsToZero()

            # Update shared memory with new global params
            if self.nWorkers > 0:
                aSharedMem = hmodel.allocModel.fillSharedMemDictForLocalStep(
                    aSharedMem)
                oSharedMem = hmodel.obsModel.fillSharedMemDictForLocalStep(
                    oSharedMem)

            # Local/Summary step for current batch
            self.algParamsLP['lapFrac'] = lapFrac  # for logging
            self.algParamsLP['batchID'] = batchID
            if self.nWorkers > 0:
                SSchunk = self.calcLocalParamsAndSummarize_parallel(
                    DataIterator, hmodel,
                    MergePrepInfo=MergePrepInfo,
                    batchID=batchID, lapFrac=lapFrac)
            else:
                SSchunk = self.calcLocalParamsAndSummarize_main(
                    DataIterator, hmodel,
                    MergePrepInfo=MergePrepInfo,
                    batchID=batchID, lapFrac=lapFrac)

            self.saveDebugStateAtBatch(
                'Estep', batchID, SSchunk=SSchunk, SS=SS, hmodel=hmodel)

            # Summary step for whole-dataset stats
            # (does incremental update)
            SS = self.memoizedSummaryStep(hmodel, SS, SSchunk, batchID)

            # Global step
            self.GlobalStep(hmodel, SS, lapFrac)

            # ELBO calculation
            if self.isLastBatch(lapFrac):
                # after seeing all data, ELBO will be ready
                self.ELBOReady = True
            if self.ELBOReady:
                evBound = hmodel.calc_evidence(SS=SS)

            # Merge move!
            if self.hasMove('merge') and self.isLastBatch(lapFrac) \
                    and lapFrac > mergeStartLap:
                hmodel, SS, evBound = self.run_many_merge_moves(
                    hmodel, SS, evBound, lapFrac, MergePrepInfo)
                # Cancel all planned deletes if merges were accepted.
                if hasattr(self, 'MergeLog') and len(self.MergeLog) > 0:
                    DeletePlans = []
                    # Update memoized stats for each batch
                    self.fastForwardMemory(Kfinal=SS.K)
                    if hasattr(SS, 'mPairIDs'):
                        del SS.mPairIDs

            # Shuffle: rearrange topic order (big to small)
            if self.hasMove('shuffle') and self.isLastBatch(lapFrac):
                order = np.argsort(-1 * SS.getCountVec())
                sortedalready = np.arange(SS.K)
                if np.allclose(order, sortedalready):
                    order = None  # Already sorted, do nothing!
                else:
                    self.ActiveIDVec = self.ActiveIDVec[order]
                    SS.reorderComps(order)
                    assert np.allclose(SS.uIDs, self.ActiveIDVec)
                    hmodel.update_global_params(SS)
                    evBound = hmodel.calc_evidence(SS=SS)
                    # Update tracked target stats for any upcoming deletes
                    for DPlan in DeletePlans:
                        if self.hasMove('merge'):
                            assert len(self.MergeLog) == 0
                        DPlan['targetSS'].reorderComps(order)
                        targetSSbyBatch = DPlan['targetSSByBatch']
                        for batchID in targetSSbyBatch:
                            targetSSbyBatch[batchID].reorderComps(order)
                    # Update memoized stats for each batch
                    self.fastForwardMemory(Kfinal=SS.K, order=order)

            # ELBO calculation
            nLapsCompleted = lapFrac - self.algParams['startLap']
            if nLapsCompleted > 1.0:
                # evBound increases monotonically AFTER first lap
                # verify_evidence warns if this isn't happening
                self.verify_evidence(evBound, prevBound, lapFrac)

            if self.doDebug() and lapFrac >= 1.0:
                self.verifyELBOTracking(hmodel, SS, evBound, order=order)

            self.saveDebugStateAtBatch(
                'Mstep', batchID, SSchunk=SSchunk, SS=SS, hmodel=hmodel)

            # Assess convergence
            countVec = SS.getCountVec()
            if lapFrac > 1.0:
                isConverged = self.isCountVecConverged(countVec, prevCountVec)
                hasMoreMoves = self.hasMoreReasonableMoves(lapFrac, SS)
                isConverged = isConverged and not hasMoreMoves
                self.setStatus(lapFrac, isConverged)

            # Display progress
            if self.isLogCheckpoint(lapFrac, iterid):
                self.printStateToLog(hmodel, evBound, lapFrac, iterid)

            # Save diagnostics and params
            if self.isSaveDiagnosticsCheckpoint(lapFrac, iterid):
                self.saveDiagnostics(lapFrac, SS, evBound)
            if self.isSaveParamsCheckpoint(lapFrac, iterid):
                self.saveParams(lapFrac, hmodel, SS)

            # Custom func hook
            self.eval_custom_func(**makeDictOfAllWorkspaceVars(**vars()))

            if isConverged and self.isLastBatch(lapFrac) \
               and nLapsCompleted >= self.algParams['minLaps']:
                break
            prevCountVec = countVec.copy()
            prevBound = evBound
            # .... end loop over data

        # Finished! Save, print and exit
        for workerID in range(self.nWorkers):
            # Passing None to JobQ is shutdown signal
            self.JobQ.put(None)

        self.printStateToLog(hmodel, evBound, lapFrac, iterid, isFinal=1)
        self.saveParams(lapFrac, hmodel, SS)
        self.eval_custom_func(
            isFinal=1, **makeDictOfAllWorkspaceVars(**vars()))

        # Births and merges work on copies of the original model object,
        # so make sure the original reference has the updated parameters.
        if id(origmodel) != id(hmodel):
            origmodel.allocModel = hmodel.allocModel
            origmodel.obsModel = hmodel.obsModel

        # Return information about this run
        return self.buildRunInfo(evBound=evBound, SS=SS,
                                 SSmemory=self.SSmemory)
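
The shuffle move above sorts components from largest to smallest expected
count, then applies the same permutation to every per-component structure so
that uIDs stay aligned with the reordered stats. A small numpy sketch:

    import numpy as np

    countVec = np.array([5.0, 40.0, 15.0])
    uIDs = np.array([101, 102, 103])

    order = np.argsort(-1 * countVec)  # big-to-small permutation
    if not np.allclose(order, np.arange(countVec.size)):
        countVec = countVec[order]
        uIDs = uIDs[order]  # apply the same order to every tracked array
    print(uIDs)  # [102 103 101]
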
Example #6
    def fit(self, hmodel, DataIterator, LP=None, **kwargs):
        ''' Run learning algorithm that fits parameters of hmodel to Data.

        Returns
        --------
        Info : dict of run information.

        Post Condition
        --------
        hmodel updated in place with improved global parameters.
        '''
        # Initialize Progress Tracking vars like nBatch, lapFrac, etc.
        iterid, lapFrac = self.initProgressTrackVars(DataIterator)

        # Save initial state
        self.saveParams(lapFrac, hmodel)

        # Custom func hook
        self.eval_custom_func(isInitial=1,
                              **makeDictOfAllWorkspaceVars(**vars()))

        # Setup workers for parallel runs
        if self.nWorkers > 0:
            JobQ, ResultQ, aSharedMem, oSharedMem = setupQueuesAndWorkers(
                DataIterator,
                hmodel,
                nWorkers=self.nWorkers,
                algParamsLP=self.algParamsLP)
            self.JobQ = JobQ
            self.ResultQ = ResultQ

        # Begin loop over batches of data...
        SS = None
        isConverged = False
        self.set_start_time_now()
        while DataIterator.has_next_batch():

            batchID = DataIterator.get_next_batch(batchIDOnly=1)

            # Update progress-tracking variables
            iterid += 1
            lapFrac = (iterid + 1) * self.lapFracInc
            self.lapFrac = lapFrac
            self.set_random_seed_at_lap(lapFrac)
            if self.doDebugVerbose():
                self.print_msg('========================== lap %.2f batch %d' %
                               (lapFrac, batchID))

            # Local/Summary step for current batch
            self.algParamsLP['lapFrac'] = lapFrac  # for logging
            self.algParamsLP['batchID'] = batchID

            if self.nWorkers > 0:
                SSchunk = self.calcLocalParamsAndSummarize_parallel(
                    DataIterator, hmodel, batchID=batchID, lapFrac=lapFrac)
            else:
                SSchunk = self.calcLocalParamsAndSummarize_main(
                    DataIterator, hmodel, batchID=batchID, lapFrac=lapFrac)

            self.saveDebugStateAtBatch('Estep',
                                       batchID,
                                       SSchunk=SSchunk,
                                       SS=SS,
                                       hmodel=hmodel)

            # Summary step for whole-dataset stats
            # (does incremental update)
            SS = self.memoizedSummaryStep(hmodel, SS, SSchunk, batchID)

            # Global step
            self.GlobalStep(hmodel, SS, lapFrac)
            if self.nWorkers > 0:
                aSharedMem = hmodel.allocModel.fillSharedMemDictForLocalStep(
                    aSharedMem)
                oSharedMem = hmodel.obsModel.fillSharedMemDictForLocalStep(
                    oSharedMem)

            # ELBO calculation
            evBound = hmodel.calc_evidence(SS=SS)
            nLapsCompleted = lapFrac - self.algParams['startLap']
            if nLapsCompleted > 1.0:
                # evBound increases monotonically AFTER first lap
                # verify_evidence warns if this isn't happening
                self.verify_evidence(evBound, prevBound, lapFrac)

            if self.doDebug() and lapFrac >= 1.0:
                self.verifyELBOTracking(hmodel, SS, evBound)

            self.saveDebugStateAtBatch('Mstep',
                                       batchID,
                                       SSchunk=SSchunk,
                                       SS=SS,
                                       hmodel=hmodel)

            # Assess convergence
            countVec = SS.getCountVec()
            if lapFrac > 1.0:
                isConverged = self.isCountVecConverged(countVec, prevCountVec)
                self.setStatus(lapFrac, isConverged)

            # Display progress
            if self.isLogCheckpoint(lapFrac, iterid):
                self.printStateToLog(hmodel, evBound, lapFrac, iterid)

            # Save diagnostics and params
            if self.isSaveDiagnosticsCheckpoint(lapFrac, iterid):
                self.saveDiagnostics(lapFrac, SS, evBound)
            if self.isSaveParamsCheckpoint(lapFrac, iterid):
                self.saveParams(lapFrac, hmodel, SS)

            # Custom func hook
            self.eval_custom_func(**makeDictOfAllWorkspaceVars(**vars()))

            if isConverged and self.isLastBatch(lapFrac) \
               and nLapsCompleted >= self.algParams['minLaps']:
                break
            prevCountVec = countVec.copy()
            prevBound = evBound
            # .... end loop over data

        # Finished! Save, print and exit
        for workerID in range(self.nWorkers):
            # Passing None to JobQ is shutdown signal
            self.JobQ.put(None)

        self.printStateToLog(hmodel, evBound, lapFrac, iterid, isFinal=1)
        self.saveParams(lapFrac, hmodel, SS)
        self.eval_custom_func(isFinal=1,
                              **makeDictOfAllWorkspaceVars(**vars()))

        if hasattr(DataIterator, 'Data'):
            Data = DataIterator.Data
        else:
            Data = DataIterator.getBatch(0)
        return self.buildRunInfo(Data=Data,
                                 evBound=evBound,
                                 SS=SS,
                                 SSmemory=self.SSmemory)
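
The verify_evidence checks used throughout these examples warn when the ELBO
fails to increase monotonically after the first lap; in full-dataset VB a
decrease signals a bug or numerical problem. A sketch of such a check (the
tolerance is an assumption):

    import numpy as np

    def verify_evidence(evBound, prevBound, rtol=1e-8):
        # Full-dataset VB: the ELBO must not decrease between laps
        isIncreasing = evBound >= prevBound - rtol * np.abs(prevBound)
        if not isIncreasing:
            print('WARNING: ELBO decreased by %.6e' % (prevBound - evBound))
        return isIncreasing

    print(verify_evidence(-100.0, -100.5))  # True: the bound improved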