def cleanup_mergenewcompsintoexisting(Data, expandModel, xSS, xLP,
                                            Korig=0, **kwargs):
  ''' Attempt to merge brand-new comps (IDs Korig..xSS.K-1) into existing ones.

      Preselects candidate pairs involving the new comps, refreshes suff
      stats (recomputing local params if the provided xLP are stale), then
      runs merge moves over the candidates.

      Returns
      --------
      mergexSS : suff stats after accepted merges
      mergexEv : ELBO evidence bound after accepted merges
  '''
  import MergeMove

  Kexpand = xSS.K
  nTrialsPerNewComp = kwargs['cleanupNumMergeTrials']
  mPairIDs = MergeMove.preselect_all_merge_candidates(
              expandModel, xSS, randstate=kwargs['randstate'],
              preselectroutine=kwargs['cleanuppreselectroutine'],
              mergePerLap=nTrialsPerNewComp * (Kexpand - Korig),
              compIDs=range(Korig, Kexpand))
  # Snapshot the candidate list before run_many_merge_moves can mutate it
  mPairIDsOrig = list(mPairIDs)

  if xLP['K'] != xSS.K:
    # Provided local params are stale, so need to recompute!
    xLP = expandModel.calc_local_params(Data)
  xSS = expandModel.get_global_suff_stats(Data, xLP,
                  doPrecompEntropy=True, doPrecompMergeEntropy=True,
                  mPairIDs=mPairIDs)

  assert 'randstate' in kwargs
  mergexModel, mergexSS, mergexEv, MTracker = MergeMove.run_many_merge_moves(
                               expandModel, Data, xSS,
                               nMergeTrials=xSS.K ** 2,
                               mPairIDs=mPairIDs,
                               **kwargs)

  # Every accepted merge must come from the preselected candidate pairs
  for acceptedPair in MTracker.acceptedOrigIDs:
    assert acceptedPair in mPairIDsOrig

  # Zero ELBO/merge bookkeeping on the expanded stats in place
  # NOTE(review): xSS is mutated here but only mergexSS is returned --
  # presumably callers rely on this side effect; verify against call sites.
  targetSS = xSS
  targetSS.setELBOFieldsToZero()
  targetSS.setMergeFieldsToZero()

  return mergexSS, mergexEv
def clean_up_expanded_suff_stats(targetData, curModel, targetSS,
                                  randstate=np.random, **kwargs):
  ''' Create expanded model combining original and brand-new comps
        and try to identify brand-new comps that are redundant copies of
        originals and can be removed

      Parameters
      --------
      targetData : dataset used to fit/evaluate the expanded model
      curModel : current model with Korig original comps
      targetSS : suff stats for the brand-new comps to append
      randstate : random number generator used for candidate preselection
                  and merge moves

      Returns
      --------
      targetSS : suff stats with redundant new comps removed, or the full
                 post-merge stats when kwargs['cleanupModifyOrigComps'] is set

      Raises
      --------
      BirthProposalError : if every brand-new comp is redundant
  '''
  import MergeMove
  Korig = curModel.allocModel.K
  origLP = curModel.calc_local_params(targetData)
  expandSS = curModel.get_global_suff_stats(targetData, origLP)
  expandSS.insertComps(targetSS)
  expandModel = curModel.copy()
  expandModel.update_global_params(expandSS)

  expandLP = expandModel.calc_local_params(targetData)
  expandSS = expandModel.get_global_suff_stats(targetData, expandLP,
                  doPrecompEntropy=True, doPrecompMergeEntropy=True)
  Kexpand = expandSS.K

  # BUGFIX: honor the randstate argument instead of hard-coding np.random,
  # so preselection is reproducible under a seeded generator (and consistent
  # with the randstate=randstate pass-through to run_many_merge_moves below).
  mPairIDs = MergeMove.preselect_all_merge_candidates(
              expandModel, expandSS, randstate=randstate,
              preselectroutine=kwargs['cleanuppreselectroutine'],
              mergePerLap=kwargs['cleanupNumMergeTrials']*(Kexpand-Korig),
              compIDs=range(Korig, Kexpand))

  # Snapshot candidates before the merge loop can mutate the list
  mPairIDsOrig = [x for x in mPairIDs]

  xModel, xSS, xEv, MTracker = MergeMove.run_many_merge_moves(
                               expandModel, targetData, expandSS,
                               nMergeTrials=expandSS.K**2,
                               mPairIDs=mPairIDs,
                               randstate=randstate, **kwargs)

  if kwargs['doVizBirth']:
    viz_birth_proposal_2D(expandModel, xModel, None, None,
                          title1='expanded model',
                          title2='after merge')

  # Every accepted merge must come from the preselected candidate pairs
  for x in MTracker.acceptedOrigIDs:
    assert x in mPairIDsOrig

  if kwargs['cleanupModifyOrigComps']:
    targetSS = xSS
    targetSS.setELBOFieldsToZero()
    targetSS.setMergeFieldsToZero()
  else:
    # Remove from targetSS all the comps whose merges were accepted
    kBList = [kB for kA, kB in MTracker.acceptedOrigIDs]

    if len(kBList) == targetSS.K:
      msg = 'BIRTH terminated. all new comps redundant with originals.'
      raise BirthProposalError(msg)
    # Delete in descending order so earlier removals don't shift later IDs
    for kB in reversed(sorted(kBList)):
      ktarget = kB - Korig
      if ktarget >= 0:
        targetSS.removeComp(ktarget)
  return targetSS
# Example #3 (separator from the original snippet collection; kept as a comment so the file parses)
# 0
def cleanup_mergenewcompsintoexisting(Data,
                                      expandModel,
                                      xSS,
                                      xLP,
                                      Korig=0,
                                      **kwargs):
    ''' Attempt to merge brand-new comps (IDs Korig..xSS.K-1) into existing ones.

        Preselects candidate pairs involving the new comps, refreshes suff
        stats (recomputing local params if the provided xLP are stale), then
        runs merge moves over the candidates.

        Returns
        --------
        mergexSS : suff stats after accepted merges
        mergexEv : ELBO evidence bound after accepted merges
    '''
    import MergeMove

    Kexpand = xSS.K
    mPairIDs = MergeMove.preselect_all_merge_candidates(
        expandModel,
        xSS,
        randstate=kwargs['randstate'],
        preselectroutine=kwargs['cleanuppreselectroutine'],
        mergePerLap=kwargs['cleanupNumMergeTrials'] * (Kexpand - Korig),
        compIDs=range(Korig, Kexpand))
    # Snapshot the candidate list before run_many_merge_moves can mutate it
    mPairIDsOrig = [x for x in mPairIDs]

    if xLP['K'] != xSS.K:
        # Provided local params are stale, so need to recompute!
        xLP = expandModel.calc_local_params(Data)
    xSS = expandModel.get_global_suff_stats(Data,
                                            xLP,
                                            doPrecompEntropy=True,
                                            doPrecompMergeEntropy=True,
                                            mPairIDs=mPairIDs)

    assert 'randstate' in kwargs
    mergexModel, mergexSS, mergexEv, MTracker = MergeMove.run_many_merge_moves(
        expandModel,
        Data,
        xSS,
        nMergeTrials=xSS.K**2,
        mPairIDs=mPairIDs,
        **kwargs)

    # Every accepted merge must come from the preselected candidate pairs
    for x in MTracker.acceptedOrigIDs:
        assert x in mPairIDsOrig

    # NOTE(review): xSS is zeroed in place here but only mergexSS is
    # returned -- presumably callers rely on this side effect; verify.
    targetSS = xSS
    targetSS.setELBOFieldsToZero()
    targetSS.setMergeFieldsToZero()

    return mergexSS, mergexEv
# Example #4 (separator from the original snippet collection; kept as a comment so the file parses)
# 0
def clean_up_expanded_suff_stats(targetData,
                                 curModel,
                                 targetSS,
                                 randstate=np.random,
                                 **kwargs):
    ''' Create expanded model combining original and brand-new comps
        and try to identify brand-new comps that are redundant copies of
        originals and can be removed

        Parameters
        --------
        targetData : dataset used to fit/evaluate the expanded model
        curModel : current model with Korig original comps
        targetSS : suff stats for the brand-new comps to append
        randstate : random number generator used for candidate preselection
                    and merge moves

        Returns
        --------
        targetSS : suff stats with redundant new comps removed, or the full
                   post-merge stats when kwargs['cleanupModifyOrigComps'] set

        Raises
        --------
        BirthProposalError : if every brand-new comp is redundant
    '''
    import MergeMove
    Korig = curModel.allocModel.K
    origLP = curModel.calc_local_params(targetData)
    expandSS = curModel.get_global_suff_stats(targetData, origLP)
    expandSS.insertComps(targetSS)
    expandModel = curModel.copy()
    expandModel.update_global_params(expandSS)

    expandLP = expandModel.calc_local_params(targetData)
    expandSS = expandModel.get_global_suff_stats(targetData,
                                                 expandLP,
                                                 doPrecompEntropy=True,
                                                 doPrecompMergeEntropy=True)
    Kexpand = expandSS.K

    # BUGFIX: honor the randstate argument instead of hard-coding np.random,
    # so preselection is reproducible under a seeded generator (and consistent
    # with the randstate=randstate pass-through below).
    mPairIDs = MergeMove.preselect_all_merge_candidates(
        expandModel,
        expandSS,
        randstate=randstate,
        preselectroutine=kwargs['cleanuppreselectroutine'],
        mergePerLap=kwargs['cleanupNumMergeTrials'] * (Kexpand - Korig),
        compIDs=range(Korig, Kexpand))

    # Snapshot candidates before the merge loop can mutate the list
    mPairIDsOrig = [x for x in mPairIDs]

    xModel, xSS, xEv, MTracker = MergeMove.run_many_merge_moves(
        expandModel,
        targetData,
        expandSS,
        nMergeTrials=expandSS.K**2,
        mPairIDs=mPairIDs,
        randstate=randstate,
        **kwargs)

    if kwargs['doVizBirth']:
        viz_birth_proposal_2D(expandModel,
                              xModel,
                              None,
                              None,
                              title1='expanded model',
                              title2='after merge')

    # Every accepted merge must come from the preselected candidate pairs
    for x in MTracker.acceptedOrigIDs:
        assert x in mPairIDsOrig

    if kwargs['cleanupModifyOrigComps']:
        targetSS = xSS
        targetSS.setELBOFieldsToZero()
        targetSS.setMergeFieldsToZero()
    else:
        # Remove from targetSS all the comps whose merges were accepted
        kBList = [kB for kA, kB in MTracker.acceptedOrigIDs]

        if len(kBList) == targetSS.K:
            msg = 'BIRTH terminated. all new comps redundant with originals.'
            raise BirthProposalError(msg)
        # Delete in descending order so earlier removals don't shift later IDs
        for kB in reversed(sorted(kBList)):
            ktarget = kB - Korig
            if ktarget >= 0:
                targetSS.removeComp(ktarget)
    return targetSS
# Example #5 (separator from the original snippet collection; kept as a comment so the file parses)
# 0
    def fit(self, hmodel, DataIterator):
        ''' Run moVB learning algorithm, fit parameters of hmodel to Data,
          traversed one batch at a time from DataIterator

        Returns
        --------
        LP : None type, cannot fit all local params in memory
        Info : dict of run information, with fields
              evBound : final ELBO evidence bound
              status : str message indicating reason for termination
                        {'converged', 'max passes exceeded'}
    
    '''
        # Define how much of data we see at each mini-batch
        nBatch = float(DataIterator.nBatch)
        self.lapFracInc = 1.0 / nBatch
        # Set-up progress-tracking variables
        iterid = -1
        lapFrac = np.maximum(0, self.algParams['startLap'] - 1.0 / nBatch)
        if lapFrac > 0:
            # When restarting an existing run,
            #  need to start with last update for final batch from previous lap
            DataIterator.lapID = int(np.ceil(lapFrac)) - 1
            DataIterator.curLapPos = nBatch - 2
            iterid = int(nBatch * lapFrac) - 1

        # memoLPkeys : keep list of params that should be retained across laps
        self.memoLPkeys = hmodel.allocModel.get_keys_for_memoized_local_params(
        )
        mPairIDs = None

        BirthPlans = list()
        BirthResults = None
        prevBirthResults = None

        SS = None
        isConverged = False
        prevBound = -np.inf
        self.set_start_time_now()
        while DataIterator.has_next_batch():

            # Grab new data
            Dchunk = DataIterator.get_next_batch()
            batchID = DataIterator.batchID

            # Update progress-tracking variables
            # lapFrac counts fractional passes through the full dataset
            iterid += 1
            lapFrac = (iterid + 1) * self.lapFracInc
            self.set_random_seed_at_lap(lapFrac)

            # M step
            if self.algParams['doFullPassBeforeMstep']:
                if SS is not None and lapFrac > 1.0:
                    hmodel.update_global_params(SS)
            else:
                if SS is not None:
                    hmodel.update_global_params(SS)

            # Birth move : track birth info from previous lap
            if self.isFirstBatch(lapFrac):
                if self.hasMove('birth') and self.do_birth_at_lap(lapFrac -
                                                                  1.0):
                    prevBirthResults = BirthResults
                else:
                    prevBirthResults = list()

            # Birth move : create new components
            if self.hasMove('birth') and self.do_birth_at_lap(lapFrac):
                if self.doBirthWithPlannedData(lapFrac):
                    hmodel, SS, BirthResults = self.birth_create_new_comps(
                        hmodel, SS, BirthPlans)

                if self.doBirthWithDataFromCurrentBatch(lapFrac):
                    # NOTE(review): BirthResults starts as None; if this path
                    # runs before the planned-data path ever assigned it,
                    # .extend would raise -- presumably the planned-data path
                    # always fires first on a birth lap; verify.
                    hmodel, SS, BirthRes = self.birth_create_new_comps(
                        hmodel, SS, Data=Dchunk)
                    BirthResults.extend(BirthRes)

                self.BirthCompIDs = self.birth_get_all_new_comps(BirthResults)
                self.ModifiedCompIDs = self.birth_get_all_modified_comps(
                    BirthResults)
            else:
                BirthResults = list()
                self.BirthCompIDs = list()  # no births = no new components
                self.ModifiedCompIDs = list()

            # Select which components to merge
            # (only once per lap, on the first batch)
            if self.hasMove(
                    'merge') and not self.algParams['merge']['doAllPairs']:
                if self.isFirstBatch(lapFrac):
                    if self.hasMove('birth'):
                        compIDs = self.BirthCompIDs
                    else:
                        compIDs = []
                    mPairIDs = MergeMove.preselect_all_merge_candidates(
                        hmodel,
                        SS,
                        randstate=self.PRNG,
                        compIDs=compIDs,
                        **self.algParams['merge'])

            # E step
            # Warm-start from memoized local params for this batch if available
            if batchID in self.LPmemory:
                oldLPchunk = self.load_batch_local_params_from_memory(
                    batchID, prevBirthResults)
                LPchunk = hmodel.calc_local_params(Dchunk, oldLPchunk,
                                                   **self.algParamsLP)
            else:
                LPchunk = hmodel.calc_local_params(Dchunk, **self.algParamsLP)

            # Collect target data for birth
            if self.hasMove('birth') and self.do_birth_at_lap(lapFrac + 1.0):
                if self.isFirstBatch(lapFrac):
                    BirthPlans = self.birth_select_targets_for_next_lap(
                        hmodel, SS, BirthResults)
                BirthPlans = self.birth_collect_target_subsample(
                    Dchunk, LPchunk, BirthPlans)
            else:
                BirthPlans = list()

            # Suff Stat step
            # Subtract this batch's stale contribution before adding fresh one
            if batchID in self.SSmemory:
                SSchunk = self.load_batch_suff_stat_from_memory(batchID, SS.K)
                SS -= SSchunk

            SSchunk = hmodel.get_global_suff_stats(
                Dchunk,
                LPchunk,
                doPrecompEntropy=True,
                doPrecompMergeEntropy=self.hasMove('merge'),
                mPairIDs=mPairIDs,
            )

            if SS is None:
                SS = SSchunk.copy()
            else:
                assert SSchunk.K == SS.K
                SS += SSchunk

            # Store batch-specific stats to memory
            if self.algParams['doMemoizeLocalParams']:
                self.save_batch_local_params_to_memory(batchID, LPchunk)
            self.save_batch_suff_stat_to_memory(batchID, SSchunk)

            # Handle removing "extra mass" of fresh components
            #  to make SS have size exactly consistent with entire dataset
            if self.hasMove('birth') and self.isLastBatch(lapFrac):
                hmodel, SS = self.birth_remove_extra_mass(
                    hmodel, SS, BirthResults)

            # ELBO calc
            #self.verify_suff_stats(Dchunk, SS, lapFrac)
            evBound = hmodel.calc_evidence(SS=SS)

            # Merge move!
            # (only attempted on whole-lap boundaries)
            if self.hasMove('merge') and isEvenlyDivisibleFloat(lapFrac, 1.):
                hmodel, SS, evBound = self.run_merge_move(
                    hmodel, SS, evBound, mPairIDs)

            # Save and display progress
            self.add_nObs(Dchunk.nObs)
            self.save_state(hmodel, iterid, lapFrac, evBound)
            self.print_state(hmodel, iterid, lapFrac, evBound)
            self.eval_custom_func(hmodel, iterid, lapFrac)

            # Check for Convergence!
            #  evBound will increase monotonically AFTER first lap of the data
            #  verify_evidence will warn if bound isn't increasing monotonically
            if lapFrac > self.algParams['startLap'] + 1.0:
                isConverged = self.verify_evidence(evBound, prevBound, lapFrac)
                if isConverged and lapFrac > 5 and not self.hasMove('birth'):
                    break
            prevBound = evBound

        # Finally, save, print and exit
        if isConverged:
            msg = "converged."
        else:
            msg = "max passes thru data exceeded."
        self.save_state(hmodel, iterid, lapFrac, evBound, doFinal=True)
        self.print_state(hmodel,
                         iterid,
                         lapFrac,
                         evBound,
                         doFinal=True,
                         status=msg)
        return None, self.buildRunInfo(evBound, msg)
  def fit(self, hmodel, DataIterator):
    ''' Run moVB learning algorithm, fit parameters of hmodel to Data,
          traversed one batch at a time from DataIterator

        Returns
        --------
        LP : None type, cannot fit all local params in memory
        Info : dict of run information, with fields
              evBound : final ELBO evidence bound
              status : str message indicating reason for termination
                        {'converged', 'max passes exceeded'}
    
    '''
    # Define how much of data we see at each mini-batch
    nBatch = float(DataIterator.nBatch)
    self.lapFracInc = 1.0/nBatch
    # Set-up progress-tracking variables
    iterid = -1
    lapFrac = np.maximum(0, self.algParams['startLap'] - 1.0/nBatch)
    if lapFrac > 0:
      # When restarting an existing run,
      #  need to start with last update for final batch from previous lap
      DataIterator.lapID = int(np.ceil(lapFrac)) - 1
      DataIterator.curLapPos = nBatch - 2
      iterid = int(nBatch * lapFrac) - 1

    # memoLPkeys : keep list of params that should be retained across laps
    self.memoLPkeys = hmodel.allocModel.get_keys_for_memoized_local_params()
    mPairIDs = None

    BirthPlans = list()
    BirthResults = None
    prevBirthResults = None

    SS = None
    isConverged = False
    prevBound = -np.inf
    self.set_start_time_now()
    while DataIterator.has_next_batch():

      # Grab new data
      Dchunk = DataIterator.get_next_batch()
      batchID = DataIterator.batchID
      
      # Update progress-tracking variables
      # lapFrac counts fractional passes through the full dataset
      iterid += 1
      lapFrac = (iterid + 1) * self.lapFracInc
      self.set_random_seed_at_lap(lapFrac)

      # M step
      if self.algParams['doFullPassBeforeMstep']:
        if SS is not None and lapFrac > 1.0:
          hmodel.update_global_params(SS)
      else:
        if SS is not None:
          hmodel.update_global_params(SS)
      
      # Birth move : track birth info from previous lap
      if self.isFirstBatch(lapFrac):
        if self.hasMove('birth') and self.do_birth_at_lap(lapFrac - 1.0):
          prevBirthResults = BirthResults
        else:
          prevBirthResults = list()

      # Birth move : create new components
      if self.hasMove('birth') and self.do_birth_at_lap(lapFrac):
        if self.doBirthWithPlannedData(lapFrac):
          hmodel, SS, BirthResults = self.birth_create_new_comps(
                                            hmodel, SS, BirthPlans)

        if self.doBirthWithDataFromCurrentBatch(lapFrac):
          # NOTE(review): BirthResults starts as None; if this path runs
          # before the planned-data path ever assigned it, .extend would
          # raise -- presumably the planned-data path fires first; verify.
          hmodel, SS, BirthRes = self.birth_create_new_comps(
                                            hmodel, SS, Data=Dchunk)
          BirthResults.extend(BirthRes)

        self.BirthCompIDs = self.birth_get_all_new_comps(BirthResults)
        self.ModifiedCompIDs = self.birth_get_all_modified_comps(BirthResults)
      else:
        BirthResults = list()
        self.BirthCompIDs = list() # no births = no new components
        self.ModifiedCompIDs = list()

      # Select which components to merge
      # (only once per lap, on the first batch)
      if self.hasMove('merge') and not self.algParams['merge']['doAllPairs']:
        if self.isFirstBatch(lapFrac):
          if self.hasMove('birth'):
            compIDs = self.BirthCompIDs
          else:
            compIDs = []
          mPairIDs = MergeMove.preselect_all_merge_candidates(hmodel, SS, 
                           randstate=self.PRNG, compIDs=compIDs,
                           **self.algParams['merge'])

      # E step
      # Warm-start from memoized local params for this batch if available
      if batchID in self.LPmemory:
        oldLPchunk = self.load_batch_local_params_from_memory(
                                           batchID, prevBirthResults)
        LPchunk = hmodel.calc_local_params(Dchunk, oldLPchunk,
                                           **self.algParamsLP)
      else:
        LPchunk = hmodel.calc_local_params(Dchunk, **self.algParamsLP)

      # Collect target data for birth
      if self.hasMove('birth') and self.do_birth_at_lap(lapFrac+1.0):
        if self.isFirstBatch(lapFrac):
          BirthPlans = self.birth_select_targets_for_next_lap(
                                hmodel, SS, BirthResults)
        BirthPlans = self.birth_collect_target_subsample(
                                Dchunk, LPchunk, BirthPlans)
      else:
        BirthPlans = list()

      # Suff Stat step
      # Subtract this batch's stale contribution before adding the fresh one
      if batchID in self.SSmemory:
        SSchunk = self.load_batch_suff_stat_from_memory(batchID, SS.K)
        SS -= SSchunk

      SSchunk = hmodel.get_global_suff_stats(Dchunk, LPchunk,
                       doPrecompEntropy=True, 
                       doPrecompMergeEntropy=self.hasMove('merge'),
                       mPairIDs=mPairIDs,
                       )

      if SS is None:
        SS = SSchunk.copy()
      else:
        assert SSchunk.K == SS.K
        SS += SSchunk

      # Store batch-specific stats to memory
      if self.algParams['doMemoizeLocalParams']:
        self.save_batch_local_params_to_memory(batchID, LPchunk)          
      self.save_batch_suff_stat_to_memory(batchID, SSchunk)  

      # Handle removing "extra mass" of fresh components
      #  to make SS have size exactly consistent with entire dataset
      if self.hasMove('birth') and self.isLastBatch(lapFrac):
        hmodel, SS = self.birth_remove_extra_mass(hmodel, SS, BirthResults)

      # ELBO calc
      #self.verify_suff_stats(Dchunk, SS, lapFrac)
      evBound = hmodel.calc_evidence(SS=SS)

      # Merge move!      
      # (only attempted on whole-lap boundaries)
      if self.hasMove('merge') and isEvenlyDivisibleFloat(lapFrac, 1.):
        hmodel, SS, evBound = self.run_merge_move(hmodel, SS, evBound, mPairIDs)

      # Save and display progress
      self.add_nObs(Dchunk.nObs)
      self.save_state(hmodel, iterid, lapFrac, evBound)
      self.print_state(hmodel, iterid, lapFrac, evBound)
      self.eval_custom_func(hmodel, iterid, lapFrac)

      # Check for Convergence!
      #  evBound will increase monotonically AFTER first lap of the data 
      #  verify_evidence will warn if bound isn't increasing monotonically
      if lapFrac > self.algParams['startLap'] + 1.0:
        isConverged = self.verify_evidence(evBound, prevBound, lapFrac)
        if isConverged and lapFrac > 5 and not self.hasMove('birth'):
          break
      prevBound = evBound

    # Finally, save, print and exit
    if isConverged:
      msg = "converged."
    else:
      msg = "max passes thru data exceeded."
    self.save_state(hmodel, iterid, lapFrac, evBound, doFinal=True) 
    self.print_state(hmodel, iterid, lapFrac,evBound,doFinal=True,status=msg)
    return None, self.buildRunInfo(evBound, msg)