def cleanup_mergenewcompsintoexisting(Data, expandModel, xSS, xLP,
                                            Korig=0, **kwargs):
  ''' Try to merge brand-new comps (ids Korig ... xSS.K-1) into existing ones.

      Preselects candidate pairs involving the new comps, refreshes suff
      stats with precomputed entropy terms, then runs many merge moves.

      Returns
      --------
      mergexSS : suff stats after accepted merges
      mergexEv : ELBO evidence bound of the merged configuration
  '''
  import MergeMove

  Kexpand = xSS.K
  mPairIDs = MergeMove.preselect_all_merge_candidates(
              expandModel, xSS, randstate=kwargs['randstate'],
              preselectroutine=kwargs['cleanuppreselectroutine'], 
              mergePerLap=kwargs['cleanupNumMergeTrials']*(Kexpand-Korig),
              compIDs=range(Korig, Kexpand))
  # Keep a copy of the candidate list, since run_many_merge_moves may
  # alter mPairIDs in place; the copy is used for the sanity check below.
  mPairIDsOrig = [x for x in mPairIDs]  

  if xLP['K'] != xSS.K:
    # Provided local params are stale, so need to recompute!
    xLP = expandModel.calc_local_params(Data)
  # Rebuild suff stats with entropy terms precomputed for the chosen pairs
  xSS = expandModel.get_global_suff_stats(Data, xLP,
                  doPrecompEntropy=True, doPrecompMergeEntropy=True,
                  mPairIDs=mPairIDs)

  assert 'randstate' in kwargs
  mergexModel, mergexSS, mergexEv, MTracker = MergeMove.run_many_merge_moves(
                               expandModel, Data, xSS,
                               nMergeTrials=xSS.K**2, 
                               mPairIDs=mPairIDs,
                               **kwargs)

  # Every accepted merge must come from the preselected candidate list
  for x in MTracker.acceptedOrigIDs:
    assert x in mPairIDsOrig
  
  # NOTE(review): targetSS is never returned; these calls zero the
  # ELBO/merge fields of xSS in place -- confirm callers expect that
  # side effect, otherwise these three lines are dead code.
  targetSS = xSS
  targetSS.setELBOFieldsToZero()
  targetSS.setMergeFieldsToZero()

  return mergexSS, mergexEv
Example #2
0
def clean_up_expanded_suff_stats(targetData, curModel, targetSS,
                                  randstate=np.random, **kwargs):
  ''' Create expanded model combining original and brand-new comps
      and try to identify brand-new comps that are redundant copies of
      originals and can be removed.

      Parameters
      --------
      targetData : dataset used to evaluate candidate merges
      curModel : current model, whose allocModel has Korig comps
      targetSS : suff stats describing the brand-new comps
      randstate : random number generator (seeded RandomState or np.random)

      Returns
      --------
      targetSS : suff stats for the new comps that survived cleanup

      Raises
      --------
      BirthProposalError : if every new comp is merged into an original
  '''
  import MergeMove
  Korig = curModel.allocModel.K
  origLP = curModel.calc_local_params(targetData)
  expandSS = curModel.get_global_suff_stats(targetData, origLP)
  # Expanded model = original comps followed by the brand-new comps
  expandSS.insertComps(targetSS)
  expandModel = curModel.copy()
  expandModel.update_global_params(expandSS)

  expandLP = expandModel.calc_local_params(targetData)
  expandSS = expandModel.get_global_suff_stats(targetData, expandLP,
                  doPrecompEntropy=True, doPrecompMergeEntropy=True)
  Kexpand = expandSS.K

  # Preselect pairs involving the new comps (ids Korig ... Kexpand-1).
  # FIX: pass the provided randstate instead of hard-coded np.random,
  # so a seeded generator makes preselection reproducible (it was already
  # passed correctly to run_many_merge_moves below).
  mPairIDs = MergeMove.preselect_all_merge_candidates(
              expandModel, expandSS, randstate=randstate,
              preselectroutine=kwargs['cleanuppreselectroutine'],
              mergePerLap=kwargs['cleanupNumMergeTrials']*(Kexpand-Korig),
              compIDs=range(Korig, Kexpand))

  # Copy candidate list; run_many_merge_moves may mutate mPairIDs in place
  mPairIDsOrig = [x for x in mPairIDs]

  xModel, xSS, xEv, MTracker = MergeMove.run_many_merge_moves(
                               expandModel, targetData, expandSS,
                               nMergeTrials=expandSS.K**2,
                               mPairIDs=mPairIDs,
                               randstate=randstate, **kwargs)

  if kwargs['doVizBirth']:
    viz_birth_proposal_2D(expandModel, xModel, None, None,
                          title1='expanded model',
                          title2='after merge')

  # Every accepted merge must come from the preselected candidate list
  for x in MTracker.acceptedOrigIDs:
    assert x in mPairIDsOrig

  if kwargs['cleanupModifyOrigComps']:
    targetSS = xSS
    targetSS.setELBOFieldsToZero()
    targetSS.setMergeFieldsToZero()
  else:
    # Remove from targetSS all the comps whose merges were accepted
    kBList = [kB for kA, kB in MTracker.acceptedOrigIDs]

    if len(kBList) == targetSS.K:
      msg = 'BIRTH terminated. all new comps redundant with originals.'
      raise BirthProposalError(msg)
    # Delete from highest index down so earlier removals don't shift ids
    for kB in reversed(sorted(kBList)):
      ktarget = kB - Korig
      if ktarget >= 0:
        targetSS.removeComp(ktarget)
  return targetSS
def cleanup_mergenewcompsonly(Data, expandModel, LP=None, 
                                    Korig=0, **kwargs):
  ''' Repeatedly attempt merges among comps with ids >= Korig.

      Sweeps up to 10 times; stops as soon as a full sweep accepts
      no merge (K unchanged).

      Returns
      --------
      mergeModel : model after all accepted merges
      mergeSS : suff stats for mergeModel
      mLP : local params from the final sweep
      mergeEv : ELBO evidence bound after merging
  '''
  import MergeMove

  mergeModel = expandModel
  Ktotal = mergeModel.obsModel.K

  # Perform many merges among the fresh components
  for trial in xrange(10):
    # Candidate pairs: every (kA, kB) with Korig <= kA < kB < Ktotal
    mPairIDs = list()
    for kA in xrange(Korig, Ktotal):
      for kB in xrange(kA+1, Ktotal):
        mPairIDs.append( (kA,kB) )

    if trial == 0 and LP is not None:
      # Reuse caller-provided local params on the first sweep only
      mLP = LP
    else:
      mLP = mergeModel.calc_local_params(Data)
    mLP['K'] = mergeModel.allocModel.K
    mSS = mergeModel.get_global_suff_stats(Data, mLP,
                    doPrecompEntropy=True, doPrecompMergeEntropy=True,
                    mPairIDs=mPairIDs)

    assert 'randstate' in kwargs
    mergeModel, mergeSS, mergeEv, MTracker = MergeMove.run_many_merge_moves(
                               mergeModel, Data, mSS, 
                               nMergeTrials=len(mPairIDs),
                               mPairIDs=mPairIDs, 
                               **kwargs)
    if mergeSS.K == Ktotal:
      break # no merges happened, so quit trying
    Ktotal = mergeSS.K


  return mergeModel, mergeSS, mLP, mergeEv
Example #4
0
def clean_up_fresh_model(targetData, curModel, freshModel, 
                            randstate=np.random, **mergeKwArgs):
  ''' Returns set of suff stats that summarize the fresh model
      1) verifies fresh model improves over default (single component) model
      2) perform merges within fresh, requiring improvement on target data
      3) perform merges within full (combined) model,
            aiming only to remove the new/fresh comps

      Raises
      --------
      BirthProposalError : if the fresh model is not measurably better
          than either the single-component model or the current model
  '''
  import MergeMove

  # Perform many merges among the fresh components
  for trial in xrange(10):
    targetLP = freshModel.calc_local_params(targetData)
    targetSS = freshModel.get_global_suff_stats(targetData, targetLP,
                    doPrecompEntropy=True, doPrecompMergeEntropy=True)
    prevK = targetSS.K
    freshModel, targetSS, freshEvBound, MTracker = MergeMove.run_many_merge_moves(
                               freshModel, targetData, targetSS,
                               nMergeTrials=targetSS.K**2, 
                               randstate=randstate, 
                               **mergeKwArgs)
    if targetSS.K == prevK:
      break # no merges happened, so quit trying

  if targetSS.K < 2:
    return targetSS # quit early, will reject

  # Create K=1 model
  singleModel = curModel.copy()
  # Seed the single comp with the stats of fresh comp 0 only
  singleSS = targetSS.getComp(0, doCollapseK1=False)
  singleModel.update_global_params(singleSS)
  singleLP = singleModel.calc_local_params(targetData)
  singleSS = singleModel.get_global_suff_stats(targetData, singleLP,
                  doPrecompEntropy=True)
  singleModel.update_global_params(singleSS) # make it reflect targetData

  # Calculate evidence under K=1 model
  singleEvBound = singleModel.calc_evidence(SS=singleSS)
 
  # Verify fresh model preferred over K=1 model
  #  require improvement of at least 0.001% of |singleEvBound|
  improveEvBound = freshEvBound - singleEvBound
  if improveEvBound <= 0 or improveEvBound < 0.00001 * abs(singleEvBound):
    msg = "BIRTH terminated. Not better than single component on target data."
    msg += "\n  fresh  | K=%3d | %.7e" % (targetSS.K, freshEvBound)
    msg += "\n  single | K=%3d | %.7e" % (singleSS.K, singleEvBound)
    raise BirthProposalError(msg)

  # Verify fresh model improves over current model 
  curLP = curModel.calc_local_params(targetData)
  curSS = curModel.get_global_suff_stats(targetData, curLP, doPrecompEntropy=True)
  curEvBound = curModel.calc_evidence(SS=curSS)
  improveEvBound = freshEvBound - curEvBound
  if improveEvBound <= 0 or improveEvBound < 0.00001 * abs(curEvBound):
    msg = "BIRTH terminated. Not better than current model on target data."
    msg += "\n  fresh | K=%3d | %.7e" % (targetSS.K, freshEvBound)
    msg += "\n  cur   | K=%3d | %.7e" % (curSS.K, curEvBound)
    raise BirthProposalError(msg)

  return targetSS
Example #5
0
def cleanup_mergenewcompsintoexisting(Data,
                                      expandModel,
                                      xSS,
                                      xLP,
                                      Korig=0,
                                      **kwargs):
    ''' Try to merge brand-new comps (ids Korig ... xSS.K-1) into existing ones.

        Returns
        --------
        mergexSS : suff stats after accepted merges
        mergexEv : ELBO evidence bound of the merged configuration
    '''
    import MergeMove

    Kexpand = xSS.K
    mPairIDs = MergeMove.preselect_all_merge_candidates(
        expandModel,
        xSS,
        randstate=kwargs['randstate'],
        preselectroutine=kwargs['cleanuppreselectroutine'],
        mergePerLap=kwargs['cleanupNumMergeTrials'] * (Kexpand - Korig),
        compIDs=range(Korig, Kexpand))
    # Copy candidate list; run_many_merge_moves may mutate mPairIDs in place
    mPairIDsOrig = [x for x in mPairIDs]

    if xLP['K'] != xSS.K:
        # Provided local params are stale, so need to recompute!
        xLP = expandModel.calc_local_params(Data)
    # Rebuild suff stats with entropy terms precomputed for the chosen pairs
    xSS = expandModel.get_global_suff_stats(Data,
                                            xLP,
                                            doPrecompEntropy=True,
                                            doPrecompMergeEntropy=True,
                                            mPairIDs=mPairIDs)

    assert 'randstate' in kwargs
    mergexModel, mergexSS, mergexEv, MTracker = MergeMove.run_many_merge_moves(
        expandModel,
        Data,
        xSS,
        nMergeTrials=xSS.K**2,
        mPairIDs=mPairIDs,
        **kwargs)

    # Every accepted merge must come from the preselected candidate list
    for x in MTracker.acceptedOrigIDs:
        assert x in mPairIDsOrig

    # NOTE(review): targetSS is never returned; these calls zero the
    # ELBO/merge fields of xSS in place -- confirm callers expect that
    # side effect, otherwise these three lines are dead code.
    targetSS = xSS
    targetSS.setELBOFieldsToZero()
    targetSS.setMergeFieldsToZero()

    return mergexSS, mergexEv
Example #6
0
def cleanup_mergenewcompsonly(Data, expandModel, LP=None, Korig=0, **kwargs):
    ''' Repeatedly merge among comps with ids >= Korig.

        Sweeps at most 10 times, quitting as soon as a full sweep
        accepts no merge (K unchanged).

        Returns
        --------
        mergeModel : model after all accepted merges
        mergeSS : suff stats for mergeModel
        mLP : local params from the final sweep
        mergeEv : ELBO evidence bound after merging
    '''
    import MergeMove

    mergeModel = expandModel
    Ktotal = mergeModel.obsModel.K

    for sweepID in xrange(10):
        # Enumerate every candidate pair (kA, kB) with
        # Korig <= kA < kB < Ktotal
        mPairIDs = [(kA, kB)
                    for kA in xrange(Korig, Ktotal)
                    for kB in xrange(kA + 1, Ktotal)]

        if sweepID == 0 and LP is not None:
            # First sweep may reuse caller-provided local params
            mLP = LP
        else:
            mLP = mergeModel.calc_local_params(Data)
        mLP['K'] = mergeModel.allocModel.K
        mSS = mergeModel.get_global_suff_stats(Data,
                                               mLP,
                                               doPrecompEntropy=True,
                                               doPrecompMergeEntropy=True,
                                               mPairIDs=mPairIDs)

        assert 'randstate' in kwargs
        mergeModel, mergeSS, mergeEv, MTracker = MergeMove.run_many_merge_moves(
            mergeModel,
            Data,
            mSS,
            nMergeTrials=len(mPairIDs),
            mPairIDs=mPairIDs,
            **kwargs)
        if mergeSS.K == Ktotal:
            # No merges accepted this sweep, so quit trying
            break
        Ktotal = mergeSS.K

    return mergeModel, mergeSS, mLP, mergeEv
Example #7
0
def clean_up_fresh_model(targetData,
                         curModel,
                         freshModel,
                         randstate=np.random,
                         **mergeKwArgs):
    ''' Returns set of suff stats that summarize the fresh model
      1) verifies fresh model improves over default (single component) model
      2) perform merges within fresh, requiring improvement on target data
      3) perform merges within full (combined) model,
            aiming only to remove the new/fresh comps

      Raises
      --------
      BirthProposalError : if the fresh model is not measurably better
          than either the single-component model or the current model
  '''
    import MergeMove

    # Perform many merges among the fresh components
    for trial in xrange(10):
        targetLP = freshModel.calc_local_params(targetData)
        targetSS = freshModel.get_global_suff_stats(targetData,
                                                    targetLP,
                                                    doPrecompEntropy=True,
                                                    doPrecompMergeEntropy=True)
        prevK = targetSS.K
        freshModel, targetSS, freshEvBound, MTracker = MergeMove.run_many_merge_moves(
            freshModel,
            targetData,
            targetSS,
            nMergeTrials=targetSS.K**2,
            randstate=randstate,
            **mergeKwArgs)
        if targetSS.K == prevK:
            break  # no merges happened, so quit trying

    if targetSS.K < 2:
        return targetSS  # quit early, will reject

    # Create K=1 model
    singleModel = curModel.copy()
    # Seed the single comp with the stats of fresh comp 0 only
    singleSS = targetSS.getComp(0, doCollapseK1=False)
    singleModel.update_global_params(singleSS)
    singleLP = singleModel.calc_local_params(targetData)
    singleSS = singleModel.get_global_suff_stats(targetData,
                                                 singleLP,
                                                 doPrecompEntropy=True)
    singleModel.update_global_params(singleSS)  # make it reflect targetData

    # Calculate evidence under K=1 model
    singleEvBound = singleModel.calc_evidence(SS=singleSS)

    # Verify fresh model preferred over K=1 model
    #  require improvement of at least 0.001% of |singleEvBound|
    improveEvBound = freshEvBound - singleEvBound
    if improveEvBound <= 0 or improveEvBound < 0.00001 * abs(singleEvBound):
        msg = "BIRTH terminated. Not better than single component on target data."
        msg += "\n  fresh  | K=%3d | %.7e" % (targetSS.K, freshEvBound)
        msg += "\n  single | K=%3d | %.7e" % (singleSS.K, singleEvBound)
        raise BirthProposalError(msg)

    # Verify fresh model improves over current model
    curLP = curModel.calc_local_params(targetData)
    curSS = curModel.get_global_suff_stats(targetData,
                                           curLP,
                                           doPrecompEntropy=True)
    curEvBound = curModel.calc_evidence(SS=curSS)
    improveEvBound = freshEvBound - curEvBound
    if improveEvBound <= 0 or improveEvBound < 0.00001 * abs(curEvBound):
        msg = "BIRTH terminated. Not better than current model on target data."
        msg += "\n  fresh | K=%3d | %.7e" % (targetSS.K, freshEvBound)
        msg += "\n  cur   | K=%3d | %.7e" % (curSS.K, curEvBound)
        raise BirthProposalError(msg)

    return targetSS
Example #8
0
def clean_up_expanded_suff_stats(targetData,
                                 curModel,
                                 targetSS,
                                 randstate=np.random,
                                 **kwargs):
    ''' Create expanded model combining original and brand-new comps
        and try to identify brand-new comps that are redundant copies of
        originals and can be removed.

        Parameters
        --------
        targetData : dataset used to evaluate candidate merges
        curModel : current model, whose allocModel has Korig comps
        targetSS : suff stats describing the brand-new comps
        randstate : random number generator (seeded RandomState or np.random)

        Returns
        --------
        targetSS : suff stats for the new comps that survived cleanup

        Raises
        --------
        BirthProposalError : if every new comp is merged into an original
    '''
    import MergeMove
    Korig = curModel.allocModel.K
    origLP = curModel.calc_local_params(targetData)
    expandSS = curModel.get_global_suff_stats(targetData, origLP)
    # Expanded model = original comps followed by the brand-new comps
    expandSS.insertComps(targetSS)
    expandModel = curModel.copy()
    expandModel.update_global_params(expandSS)

    expandLP = expandModel.calc_local_params(targetData)
    expandSS = expandModel.get_global_suff_stats(targetData,
                                                 expandLP,
                                                 doPrecompEntropy=True,
                                                 doPrecompMergeEntropy=True)
    Kexpand = expandSS.K

    # Preselect pairs involving the new comps (ids Korig ... Kexpand-1).
    # FIX: pass the provided randstate instead of hard-coded np.random,
    # so a seeded generator makes preselection reproducible (it was
    # already passed correctly to run_many_merge_moves below).
    mPairIDs = MergeMove.preselect_all_merge_candidates(
        expandModel,
        expandSS,
        randstate=randstate,
        preselectroutine=kwargs['cleanuppreselectroutine'],
        mergePerLap=kwargs['cleanupNumMergeTrials'] * (Kexpand - Korig),
        compIDs=range(Korig, Kexpand))

    # Copy candidate list; run_many_merge_moves may mutate mPairIDs in place
    mPairIDsOrig = [x for x in mPairIDs]

    xModel, xSS, xEv, MTracker = MergeMove.run_many_merge_moves(
        expandModel,
        targetData,
        expandSS,
        nMergeTrials=expandSS.K**2,
        mPairIDs=mPairIDs,
        randstate=randstate,
        **kwargs)

    if kwargs['doVizBirth']:
        viz_birth_proposal_2D(expandModel,
                              xModel,
                              None,
                              None,
                              title1='expanded model',
                              title2='after merge')

    # Every accepted merge must come from the preselected candidate list
    for x in MTracker.acceptedOrigIDs:
        assert x in mPairIDsOrig

    if kwargs['cleanupModifyOrigComps']:
        targetSS = xSS
        targetSS.setELBOFieldsToZero()
        targetSS.setMergeFieldsToZero()
    else:
        # Remove from targetSS all the comps whose merges were accepted
        kBList = [kB for kA, kB in MTracker.acceptedOrigIDs]

        if len(kBList) == targetSS.K:
            msg = 'BIRTH terminated. all new comps redundant with originals.'
            raise BirthProposalError(msg)
        # Delete from highest index down so earlier removals don't shift ids
        for kB in reversed(sorted(kBList)):
            ktarget = kB - Korig
            if ktarget >= 0:
                targetSS.removeComp(ktarget)
    return targetSS
Example #9
0
    def fit(self, hmodel, DataIterator):
        ''' Run moVB learning algorithm, fit parameters of hmodel to Data,
          traversed one batch at a time from DataIterator

        Returns
        --------
        LP : None type, cannot fit all local params in memory
        Info : dict of run information, with fields
              evBound : final ELBO evidence bound
              status : str message indicating reason for termination
                        {'converged', 'max passes exceeded'}
    
    '''
        # Define how much of data we see at each mini-batch
        nBatch = float(DataIterator.nBatch)
        self.lapFracInc = 1.0 / nBatch
        # Set-up progress-tracking variables
        iterid = -1
        lapFrac = np.maximum(0, self.algParams['startLap'] - 1.0 / nBatch)
        if lapFrac > 0:
            # When restarting an existing run,
            #  need to start with last update for final batch from previous lap
            DataIterator.lapID = int(np.ceil(lapFrac)) - 1
            DataIterator.curLapPos = nBatch - 2
            iterid = int(nBatch * lapFrac) - 1

        # memoLPkeys : keep list of params that should be retained across laps
        self.memoLPkeys = hmodel.allocModel.get_keys_for_memoized_local_params(
        )
        mPairIDs = None

        BirthPlans = list()
        BirthResults = None
        prevBirthResults = None

        SS = None
        isConverged = False
        prevBound = -np.inf
        self.set_start_time_now()
        while DataIterator.has_next_batch():

            # Grab new data
            Dchunk = DataIterator.get_next_batch()
            batchID = DataIterator.batchID

            # Update progress-tracking variables
            iterid += 1
            lapFrac = (iterid + 1) * self.lapFracInc
            self.set_random_seed_at_lap(lapFrac)

            # M step
            if self.algParams['doFullPassBeforeMstep']:
                # Defer global updates until one full lap of data is seen
                if SS is not None and lapFrac > 1.0:
                    hmodel.update_global_params(SS)
            else:
                if SS is not None:
                    hmodel.update_global_params(SS)

            # Birth move : track birth info from previous lap
            if self.isFirstBatch(lapFrac):
                if self.hasMove('birth') and self.do_birth_at_lap(lapFrac -
                                                                  1.0):
                    prevBirthResults = BirthResults
                else:
                    prevBirthResults = list()

            # Birth move : create new components
            if self.hasMove('birth') and self.do_birth_at_lap(lapFrac):
                if self.doBirthWithPlannedData(lapFrac):
                    hmodel, SS, BirthResults = self.birth_create_new_comps(
                        hmodel, SS, BirthPlans)

                if self.doBirthWithDataFromCurrentBatch(lapFrac):
                    hmodel, SS, BirthRes = self.birth_create_new_comps(
                        hmodel, SS, Data=Dchunk)
                    BirthResults.extend(BirthRes)

                self.BirthCompIDs = self.birth_get_all_new_comps(BirthResults)
                self.ModifiedCompIDs = self.birth_get_all_modified_comps(
                    BirthResults)
            else:
                BirthResults = list()
                self.BirthCompIDs = list()  # no births = no new components
                self.ModifiedCompIDs = list()

            # Select which components to merge
            if self.hasMove(
                    'merge') and not self.algParams['merge']['doAllPairs']:
                if self.isFirstBatch(lapFrac):
                    if self.hasMove('birth'):
                        compIDs = self.BirthCompIDs
                    else:
                        compIDs = []
                    mPairIDs = MergeMove.preselect_all_merge_candidates(
                        hmodel,
                        SS,
                        randstate=self.PRNG,
                        compIDs=compIDs,
                        **self.algParams['merge'])

            # E step
            if batchID in self.LPmemory:
                # Warm-start from this batch's memoized local params
                oldLPchunk = self.load_batch_local_params_from_memory(
                    batchID, prevBirthResults)
                LPchunk = hmodel.calc_local_params(Dchunk, oldLPchunk,
                                                   **self.algParamsLP)
            else:
                LPchunk = hmodel.calc_local_params(Dchunk, **self.algParamsLP)

            # Collect target data for birth
            if self.hasMove('birth') and self.do_birth_at_lap(lapFrac + 1.0):
                if self.isFirstBatch(lapFrac):
                    BirthPlans = self.birth_select_targets_for_next_lap(
                        hmodel, SS, BirthResults)
                BirthPlans = self.birth_collect_target_subsample(
                    Dchunk, LPchunk, BirthPlans)
            else:
                BirthPlans = list()

            # Suff Stat step
            if batchID in self.SSmemory:
                # Subtract this batch's stale contribution before adding
                # the freshly computed stats below
                SSchunk = self.load_batch_suff_stat_from_memory(batchID, SS.K)
                SS -= SSchunk

            SSchunk = hmodel.get_global_suff_stats(
                Dchunk,
                LPchunk,
                doPrecompEntropy=True,
                doPrecompMergeEntropy=self.hasMove('merge'),
                mPairIDs=mPairIDs,
            )

            if SS is None:
                SS = SSchunk.copy()
            else:
                assert SSchunk.K == SS.K
                SS += SSchunk

            # Store batch-specific stats to memory
            if self.algParams['doMemoizeLocalParams']:
                self.save_batch_local_params_to_memory(batchID, LPchunk)
            self.save_batch_suff_stat_to_memory(batchID, SSchunk)

            # Handle removing "extra mass" of fresh components
            #  to make SS have size exactly consistent with entire dataset
            if self.hasMove('birth') and self.isLastBatch(lapFrac):
                hmodel, SS = self.birth_remove_extra_mass(
                    hmodel, SS, BirthResults)

            # ELBO calc
            #self.verify_suff_stats(Dchunk, SS, lapFrac)
            evBound = hmodel.calc_evidence(SS=SS)

            # Merge move! (only at whole-lap boundaries)
            if self.hasMove('merge') and isEvenlyDivisibleFloat(lapFrac, 1.):
                hmodel, SS, evBound = self.run_merge_move(
                    hmodel, SS, evBound, mPairIDs)

            # Save and display progress
            self.add_nObs(Dchunk.nObs)
            self.save_state(hmodel, iterid, lapFrac, evBound)
            self.print_state(hmodel, iterid, lapFrac, evBound)
            self.eval_custom_func(hmodel, iterid, lapFrac)

            # Check for Convergence!
            #  evBound will increase monotonically AFTER first lap of the data
            #  verify_evidence will warn if bound isn't increasing monotonically
            if lapFrac > self.algParams['startLap'] + 1.0:
                isConverged = self.verify_evidence(evBound, prevBound, lapFrac)
                if isConverged and lapFrac > 5 and not self.hasMove('birth'):
                    break
            prevBound = evBound

        # Finally, save, print and exit
        if isConverged:
            msg = "converged."
        else:
            msg = "max passes thru data exceeded."
        self.save_state(hmodel, iterid, lapFrac, evBound, doFinal=True)
        self.print_state(hmodel,
                         iterid,
                         lapFrac,
                         evBound,
                         doFinal=True,
                         status=msg)
        return None, self.buildRunInfo(evBound, msg)
  def fit(self, hmodel, DataIterator):
    ''' Run moVB learning algorithm, fit parameters of hmodel to Data,
          traversed one batch at a time from DataIterator

        Returns
        --------
        LP : None type, cannot fit all local params in memory
        Info : dict of run information, with fields
              evBound : final ELBO evidence bound
              status : str message indicating reason for termination
                        {'converged', 'max passes exceeded'}
    
    '''
    # Define how much of data we see at each mini-batch
    nBatch = float(DataIterator.nBatch)
    self.lapFracInc = 1.0/nBatch
    # Set-up progress-tracking variables
    iterid = -1
    lapFrac = np.maximum(0, self.algParams['startLap'] - 1.0/nBatch)
    if lapFrac > 0:
      # When restarting an existing run,
      #  need to start with last update for final batch from previous lap
      DataIterator.lapID = int(np.ceil(lapFrac)) - 1
      DataIterator.curLapPos = nBatch - 2
      iterid = int(nBatch * lapFrac) - 1

    # memoLPkeys : keep list of params that should be retained across laps
    self.memoLPkeys = hmodel.allocModel.get_keys_for_memoized_local_params()
    mPairIDs = None

    BirthPlans = list()
    BirthResults = None
    prevBirthResults = None

    SS = None
    isConverged = False
    prevBound = -np.inf
    self.set_start_time_now()
    while DataIterator.has_next_batch():

      # Grab new data
      Dchunk = DataIterator.get_next_batch()
      batchID = DataIterator.batchID
      
      # Update progress-tracking variables
      iterid += 1
      lapFrac = (iterid + 1) * self.lapFracInc
      self.set_random_seed_at_lap(lapFrac)

      # M step
      if self.algParams['doFullPassBeforeMstep']:
        # Defer global updates until one full lap of data is seen
        if SS is not None and lapFrac > 1.0:
          hmodel.update_global_params(SS)
      else:
        if SS is not None:
          hmodel.update_global_params(SS)
      
      # Birth move : track birth info from previous lap
      if self.isFirstBatch(lapFrac):
        if self.hasMove('birth') and self.do_birth_at_lap(lapFrac - 1.0):
          prevBirthResults = BirthResults
        else:
          prevBirthResults = list()

      # Birth move : create new components
      if self.hasMove('birth') and self.do_birth_at_lap(lapFrac):
        if self.doBirthWithPlannedData(lapFrac):
          hmodel, SS, BirthResults = self.birth_create_new_comps(
                                            hmodel, SS, BirthPlans)

        if self.doBirthWithDataFromCurrentBatch(lapFrac):
          hmodel, SS, BirthRes = self.birth_create_new_comps(
                                            hmodel, SS, Data=Dchunk)
          BirthResults.extend(BirthRes)

        self.BirthCompIDs = self.birth_get_all_new_comps(BirthResults)
        self.ModifiedCompIDs = self.birth_get_all_modified_comps(BirthResults)
      else:
        BirthResults = list()
        self.BirthCompIDs = list() # no births = no new components
        self.ModifiedCompIDs = list()

      # Select which components to merge
      if self.hasMove('merge') and not self.algParams['merge']['doAllPairs']:
        if self.isFirstBatch(lapFrac):
          if self.hasMove('birth'):
            compIDs = self.BirthCompIDs
          else:
            compIDs = []
          mPairIDs = MergeMove.preselect_all_merge_candidates(hmodel, SS, 
                           randstate=self.PRNG, compIDs=compIDs,
                           **self.algParams['merge'])

      # E step
      if batchID in self.LPmemory:
        # Warm-start from this batch's memoized local params
        oldLPchunk = self.load_batch_local_params_from_memory(
                                           batchID, prevBirthResults)
        LPchunk = hmodel.calc_local_params(Dchunk, oldLPchunk,
                                           **self.algParamsLP)
      else:
        LPchunk = hmodel.calc_local_params(Dchunk, **self.algParamsLP)

      # Collect target data for birth
      if self.hasMove('birth') and self.do_birth_at_lap(lapFrac+1.0):
        if self.isFirstBatch(lapFrac):
          BirthPlans = self.birth_select_targets_for_next_lap(
                                hmodel, SS, BirthResults)
        BirthPlans = self.birth_collect_target_subsample(
                                Dchunk, LPchunk, BirthPlans)
      else:
        BirthPlans = list()

      # Suff Stat step
      if batchID in self.SSmemory:
        # Subtract this batch's stale contribution before adding fresh stats
        SSchunk = self.load_batch_suff_stat_from_memory(batchID, SS.K)
        SS -= SSchunk

      SSchunk = hmodel.get_global_suff_stats(Dchunk, LPchunk,
                       doPrecompEntropy=True, 
                       doPrecompMergeEntropy=self.hasMove('merge'),
                       mPairIDs=mPairIDs,
                       )

      if SS is None:
        SS = SSchunk.copy()
      else:
        assert SSchunk.K == SS.K
        SS += SSchunk

      # Store batch-specific stats to memory
      if self.algParams['doMemoizeLocalParams']:
        self.save_batch_local_params_to_memory(batchID, LPchunk)          
      self.save_batch_suff_stat_to_memory(batchID, SSchunk)  

      # Handle removing "extra mass" of fresh components
      #  to make SS have size exactly consistent with entire dataset
      if self.hasMove('birth') and self.isLastBatch(lapFrac):
        hmodel, SS = self.birth_remove_extra_mass(hmodel, SS, BirthResults)

      # ELBO calc
      #self.verify_suff_stats(Dchunk, SS, lapFrac)
      evBound = hmodel.calc_evidence(SS=SS)

      # Merge move! (only at whole-lap boundaries)
      if self.hasMove('merge') and isEvenlyDivisibleFloat(lapFrac, 1.):
        hmodel, SS, evBound = self.run_merge_move(hmodel, SS, evBound, mPairIDs)

      # Save and display progress
      self.add_nObs(Dchunk.nObs)
      self.save_state(hmodel, iterid, lapFrac, evBound)
      self.print_state(hmodel, iterid, lapFrac, evBound)
      self.eval_custom_func(hmodel, iterid, lapFrac)

      # Check for Convergence!
      #  evBound will increase monotonically AFTER first lap of the data 
      #  verify_evidence will warn if bound isn't increasing monotonically
      if lapFrac > self.algParams['startLap'] + 1.0:
        isConverged = self.verify_evidence(evBound, prevBound, lapFrac)
        if isConverged and lapFrac > 5 and not self.hasMove('birth'):
          break
      prevBound = evBound

    # Finally, save, print and exit
    if isConverged:
      msg = "converged."
    else:
      msg = "max passes thru data exceeded."
    self.save_state(hmodel, iterid, lapFrac, evBound, doFinal=True) 
    self.print_state(hmodel, iterid, lapFrac,evBound,doFinal=True,status=msg)
    return None, self.buildRunInfo(evBound, msg)
Example #11
0
  def run_merge_move(self, hmodel, Data, SS, LP, evBound):
    ''' Run merge move on hmodel

        Attempts up to mergePerLap merges, tracking which comps and
        pairs have been tried/accepted so they are not retried with
        stale precomputed entropy terms.

        Returns
        --------
        hmodel, SS, LP, evBound : updated to reflect accepted merges
    ''' 
    import MergeMove
    excludeList = list()
    excludePairs = defaultdict(lambda:set())    
    nMergeAttempts = self.algParams['merge']['mergePerLap']
    trialID = 0
    while trialID < nMergeAttempts:

      # Synchronize contents of the excludeList and excludePairs
      # So that comp excluded in excludeList (due to accepted merge)
      #  is automatically contained in the set of excluded pairs 
      for kx in excludeList:
        for kk in excludePairs:
          excludePairs[kk].add(kx)
          excludePairs[kx].add(kk)

      # A comp whose pair-exclusions cover all possible partners
      # is fully excluded from further merge attempts
      for kk in excludePairs:
        if len(excludePairs[kk]) > hmodel.obsModel.K - 2:
          if kk not in excludeList:
            excludeList.append(kk)

      if len(excludeList) > hmodel.obsModel.K - 2:
        break # when we don't have any more comps to merge
        
      # Prefer merging comps recorded in BirthLog, if any remain
      if len(self.BirthLog) > 0:
        kA = self.BirthLog.pop()
        if kA in excludeList:
          continue
      else:
        kA = None

      oldEv = hmodel.calc_evidence(SS=SS)
      hmodel, SS, evBound, MoveInfo = MergeMove.run_merge_move(
                 hmodel, Data, SS, evBound, kA=kA, randstate=self.PRNG,
                 excludeList=excludeList, excludePairs=excludePairs,
                  **self.algParams['merge'])
      newEv = hmodel.calc_evidence(SS=SS)
      
      trialID += 1
      self.print_msg(MoveInfo['msg'])
      # Never retry the same pair, accepted or not
      if 'kA' in MoveInfo and 'kB' in MoveInfo:
        kA = MoveInfo['kA']
        kB = MoveInfo['kB']
        excludePairs[kA].add(kB)
        excludePairs[kB].add(kA)

      if MoveInfo['didAccept']:
        # Accepted merges must strictly improve the evidence
        assert newEv > oldEv
        kA = MoveInfo['kA']
        kB = MoveInfo['kB']
        # Adjust excludeList since components kB+1, kB+2, ... K
        #  have been shifted down by one due to removal of kB
        for kk in range(len(excludeList)):
          if excludeList[kk] > kB:
            excludeList[kk] -= 1

        # Exclude new merged component kA from future attempts        
        #  since precomputed entropy terms involving kA aren't good
        excludeList.append(kA)

        # Adjust excluded pairs to remove kB and shift down kB+1, ... K
        newExcludePairs = defaultdict(lambda:set())
        for kk in excludePairs.keys():
          ksarr = np.asarray(list(excludePairs[kk]))
          ksarr[ksarr > kB] -= 1
          if kk > kB:
            newExcludePairs[kk-1] = set(ksarr)
          elif kk < kB:
            newExcludePairs[kk] = set(ksarr)
        excludePairs = newExcludePairs

        # Update LP to reflect this merge!
        # Responsibilities of kA and kB combine into column kA;
        #  column kB is deleted so LP matches the new K
        LPkeys = LP.keys()
        keepLPkeys = hmodel.allocModel.get_keys_for_memoized_local_params()

        for key in LPkeys:
          if key in keepLPkeys:
            LP[key][:, kA] = LP[key][:, kA] + LP[key][:, kB]
            LP[key] = np.delete(LP[key], kB, axis=1)
    return hmodel, SS, LP, evBound