Exemplo n.º 1
0
def sampleAndCalcMI(wtSer,
                    nSamp,
                    nIter,
                    sampler,
                    testSampParams,
                    genSampParams,
                    mutator,
                    mutatorParams,
                    lnLik,
                    drawGraph=False,
                    verbose=False):
    """Draw Metropolis samples against a test sample and score them.

    A test sample and an initial guess are both drawn from ``sampler``
    (called as ``sampler(nSamp, **params)``).  The guess is handed to
    ``genMetropolisSamples`` together with ``lnLik`` and a parameter dict
    holding the test sample's ``.values`` and ``wtSer``.  The chunks it
    returns are concatenated, the test values are repeated once per chunk,
    and the two arrays are passed to ``mutualInfo``.

    Args:
        wtSer: Weights forwarded to ``lnLik`` under the key ``'wtSerV'``.
        nSamp: Sample count passed to ``sampler`` and the Metropolis run.
        nIter: Iteration count passed to ``genMetropolisSamples``.
        sampler: Callable ``sampler(nSamp, **kwargs)``; its result for the
            test parameters must expose ``.values`` (presumably a pandas
            object -- confirm against callers).
        testSampParams: Keyword arguments for drawing the test sample.
        genSampParams: Keyword arguments for drawing the initial guess.
        mutator: Forwarded unchanged to ``genMetropolisSamples``.
        mutatorParams: Forwarded unchanged to ``genMetropolisSamples``.
        lnLik: Log-likelihood callable forwarded to ``genMetropolisSamples``.
        drawGraph: If true, show a log-scaled 64x64 2-D histogram of the
            bins of the test values versus the bins of the result.
        verbose: Forwarded to ``genMetropolisSamples``.

    Returns:
        Whatever ``mutualInfo(cleanV, expandedTestV)`` returns.
    """
    testSamps = sampler(nSamp, **testSampParams)
    startingGuess = sampler(nSamp, **genSampParams)
    likelihoodArgs = dict(samps2V=testSamps.values, wtSerV=wtSer)
    chunks = genMetropolisSamples(nSamp, nIter, startingGuess,
                                  lnLik, likelihoodArgs,
                                  mutator, mutatorParams,
                                  verbose=verbose)

    # Stack all chunks, and repeat the test values once per chunk so the two
    # arrays line up for the mutual-information estimate.
    cleanV = np.concatenate(chunks)
    expandedTestV = np.concatenate((testSamps.values,) * len(chunks))

    if drawGraph:
        # whichBin returns a pair; only the bin assignments are needed here.
        testBins, _ = whichBin(expandedTestV)
        rsltBins = whichBin(cleanV)[0]
        counts = np.histogram2d(testBins, rsltBins, bins=64)[0]
        # +1 keeps empty cells finite under the log.
        plt.imshow(np.log(counts + 1))
        plt.show()

    return mutualInfo(cleanV, expandedTestV)
Exemplo n.º 2
0
    def age_transition(self, new_all_samples):
        """Advance this cohort from ``self.age`` to ``self.age + 1``.

        The steps, in order:

        1. Select the rows of ``new_all_samples`` matching ``self.fixed_d``
           and pass them through ``self.samp_gen`` to form the new outer
           cohort.
        2. Build an ``MSTMutator`` on that outer cohort and fit a weight
           vector by calling ``minimize`` on ``minimizeMe`` with
           ``method='L-BFGS-B'`` (presumably ``scipy.optimize.minimize`` --
           confirm against the file's imports).  The search starts from the
           default weights and each weight is bounded to ``[0.25*v, 4.0*v]``
           of its starting value ``v``.
        3. Run ``genMetropolisSamples`` with the fitted weights against the
           current inner cohort, concatenate the returned chunks and pass
           them through ``self.samp_gen`` to form the new inner cohort.
        4. Store the new age and both new cohorts on ``self``.

        Progress and diagnostic information is printed to stdout throughout.

        Args:
            new_all_samples: All samples at the new age; must expose
                ``.index`` (presumably a pandas DataFrame -- confirm).

        Returns:
            None.  ``self.age``, ``self.outer_cohort`` and
            ``self.inner_cohort`` are updated in place.
        """
        # each fixed key must get updated according to its own rule
        # - gender, birthorder, etc. stay fixed
        # the outer cohort is all samples at the new age with the same fixed keys
        print('new_all_samples unique entries: ', new_all_samples.index.unique())
        all_col_l = self.open_l + self.advancing_l
        which_bin = createBinner(all_col_l, range_d=self.range_d)
        wt_ser = createWeightSer(all_col_l, range_d=self.range_d)
        samples_subset = select_subset(new_all_samples, self.fixed_d)
        # Report the subset itself (the original printed new_all_samples
        # again under this label).
        print('samples_subset unique entries: ', samples_subset.index.unique())
        new_outer_cohort = self.samp_gen(samples_subset)
        new_age = self.age + 1
        print('------------------')
        print('new outer cohort unique entries: ', new_outer_cohort.index.unique())
        print('starting %s -> %s' % (self.age, new_age))
        print('------------------')
        nSamp = len(self.inner_cohort)
        nIter = 1000
        testSampParams = {'df': self.inner_cohort}
        genSampParams = {'df': new_outer_cohort}
        binnerParams = {}
        #mutator = FreshDrawMutator()
        mutator = MSTMutator(new_outer_cohort)
        mutator.plot_tree()
        mutatorParams = {'nsteps': 2, 'df': new_outer_cohort}

        # Fit the weights; the extra tuple is forwarded to minimizeMe as its
        # additional positional arguments.
        rslt = minimize(minimizeMe, wt_ser.values.copy(),
                        (nSamp, nIter, all_col_l,
                         self.samp_gen,
                         testSampParams, genSampParams,
                         which_bin, binnerParams,
                         mutator, mutatorParams),
                        method='L-BFGS-B',
                        bounds=[(0.25*v, 4.0*v) for v in wt_ser.values])
        print('------------------')
        print('Optimization result:')
        print(rslt)
        print('------------------')

        # Re-sample with the fitted weights to produce the new inner cohort.
        bestWtSer = createWeightSer(all_col_l, {}, rslt.x)
        lnLikParams = {'samps2V': self.inner_cohort, 'wtSerV': bestWtSer}
        cleanSamps = genMetropolisSamples(nSamp, nIter, self.samp_gen(**genSampParams),
                                          lnLik, lnLikParams,
                                          mutator, mutatorParams, verbose=True)
        # The sampler may return DataFrame chunks or array chunks.
        if isinstance(cleanSamps[0], pd.DataFrame):
            newCleanV = pd.concat(cleanSamps)
        else:
            newCleanV = np.concatenate(cleanSamps)
        new_inner_cohort = self.samp_gen(newCleanV)

        print('new inner cohort unique entries: ', new_inner_cohort.index.unique())
        print('------------------')

        self.age = new_age
        self.outer_cohort = new_outer_cohort
        self.inner_cohort = new_inner_cohort
Exemplo n.º 3
0
 def gen_samples_using(self, target_cohort, pool_cohort,
                       mutator, mutator_params,
                       wt_ser, niter=1000):
         """Generate samples from ``pool_cohort`` that fit ``target_cohort``.

         Runs ``genMetropolisSamples`` for ``niter`` iterations, starting
         from ``self.samp_gen(df=pool_cohort)`` and scoring with ``lnLik``
         against ``target_cohort`` and ``wt_ser``.  The returned chunks are
         concatenated (``pd.concat`` for DataFrame chunks, otherwise
         ``np.concatenate``) and passed through ``self.samp_gen``.

         Args:
             target_cohort: Samples to match; its length sets the sample
                 count, and it is given to ``lnLik`` as ``'samps2V'``.
             pool_cohort: Passed to ``self.samp_gen`` as ``df`` to produce
                 the starting samples.
             mutator: Forwarded unchanged to ``genMetropolisSamples``.
             mutator_params: Forwarded unchanged to ``genMetropolisSamples``.
             wt_ser: Weights given to ``lnLik`` as ``'wtSerV'``.
             niter: Iteration count for the Metropolis run.

         Returns:
             The result of ``self.samp_gen`` applied to the concatenated
             samples.
         """
         n_target = len(target_cohort)
         likelihood_kwargs = dict(samps2V=target_cohort, wtSerV=wt_ser)
         chunks = genMetropolisSamples(
             n_target, niter, self.samp_gen(df=pool_cohort),
             lnLik, likelihood_kwargs,
             mutator, mutator_params, verbose=True)
         # Chunks are either DataFrames or arrays; pick the matching joiner.
         if isinstance(chunks[0], pd.DataFrame):
             joined = pd.concat(chunks)
         else:
             joined = np.concatenate(chunks)
         return self.samp_gen(joined)
Exemplo n.º 4
0
def generateSamples(oldSamps,
                    wtVec,
                    workingCols,
                    nIter,
                    sampler,
                    genSampParams,
                    srcDF,
                    mutator,
                    mutatorParams,
                    verbose=False):
    """Generate new samples matching ``oldSamps`` under the given weights.

    A weight Series is built by pairing ``workingCols`` with ``wtVec``
    position by position, ordered like ``srcDF.columns`` and restricted to
    the working columns.  An initial guess drawn from ``sampler`` is then
    refined by ``genMetropolisSamples`` (with ``mutationsPerSamp=2`` and
    ``burninMutations=4``), and the returned chunks are concatenated.

    Args:
        oldSamps: Samples to match; its length sets the sample count and it
            is given to ``lnLik`` as ``'samps2V'``.
        wtVec: Sequence of weights, one per entry of ``workingCols`` in the
            same order.
        workingCols: Column names that carry a weight.
        nIter: Iteration count for the Metropolis run.
        sampler: Callable ``sampler(nSamp, **genSampParams)`` producing the
            initial guess.
        genSampParams: Keyword arguments for ``sampler``.
        srcDF: Object whose ``.columns`` fixes the order of the weights
            (presumably a pandas DataFrame -- confirm against callers).
        mutator: Forwarded unchanged to ``genMetropolisSamples``.
        mutatorParams: Forwarded unchanged to ``genMetropolisSamples``.
        verbose: Forwarded to ``genMetropolisSamples``.

    Returns:
        The ``np.concatenate`` of the chunks returned by
        ``genMetropolisSamples``.

    Raises:
        ValueError: If ``wtVec`` has fewer entries than ``workingCols``.
    """
    if len(wtVec) < len(workingCols):
        raise ValueError(
            'wtVec has %d entries but workingCols has %d'
            % (len(wtVec), len(workingCols)))

    # Get the right index order but no extra entries: index by the source
    # columns first (non-working columns become NaN), then drop those.
    # The weights come from wtVec; the previous code filled in the positional
    # indices 0, 1, 2, ... and ignored wtVec entirely.
    wtSer = pd.Series(dict(zip(workingCols, wtVec)), index=srcDF.columns)
    dropL = [col for col in wtSer.index if col not in workingCols]
    wtSer = wtSer.drop(labels=dropL)

    nSamp = len(oldSamps)
    guess = sampler(nSamp, **genSampParams)
    lnLikParams = {'samps2V': oldSamps, 'wtSerV': wtSer}
    cleanSamps = genMetropolisSamples(nSamp,
                                      nIter,
                                      guess,
                                      lnLik,
                                      lnLikParams,
                                      mutator,
                                      mutatorParams,
                                      verbose=verbose,
                                      mutationsPerSamp=2,
                                      burninMutations=4)
    return np.concatenate(cleanSamps)
def sampleAndCalcMI(wtSer,
                    nSamp,
                    nIter,
                    sampler,
                    testSampParams,
                    genSampParams,
                    binner,
                    binnerParams,
                    mutator,
                    mutatorParams,
                    drawGraph=False,
                    verbose=False):
    """Draw Metropolis samples against a padded test set and score them.

    The test set is ``testSampParams['df']`` with its ``'FWC'`` column
    dropped, topped up to ``testSampParams['n_samp']`` rows with
    ``mkSamps(df, n_samp - len(df))``.  An initial guess from
    ``sampler(**genSampParams)`` is refined by ``genMetropolisSamples``;
    the returned chunks are concatenated, the test set is repeated once per
    chunk, and both are passed to ``mutualInfo`` along with the binner.

    Args:
        wtSer: Weights forwarded to ``lnLik`` under the key ``'wtSerV'``.
        nSamp: Sample count passed to ``genMetropolisSamples`` (the test-set
            size comes from ``testSampParams['n_samp']`` instead).
        nIter: Iteration count passed to ``genMetropolisSamples``.
        sampler: Callable invoked as ``sampler(**genSampParams)``.
        testSampParams: Dict with keys ``'df'`` and ``'n_samp'``.
        genSampParams: Keyword arguments for ``sampler``.
        binner: Forwarded to ``mutualInfo``.
        binnerParams: Forwarded to ``mutualInfo`` as ``binnerParams``.
        mutator: Forwarded unchanged to ``genMetropolisSamples``.
        mutatorParams: Forwarded unchanged to ``genMetropolisSamples``.
        drawGraph: If true, show a log-scaled 64x64 2-D histogram of test
            bins versus result bins.
        verbose: Forwarded to ``genMetropolisSamples``.

    Returns:
        Whatever ``mutualInfo`` returns for the concatenated samples and
        the repeated test set.

    Raises:
        AssertionError: If the test df already has ``n_samp`` or more rows.
    """
    testDF = testSampParams['df']
    targetCount = testSampParams['n_samp']
    assert len(testDF) < targetCount, 'test df is too big'

    # Real rows (minus the 'FWC' column) plus enough mkSamps rows to reach
    # targetCount in total.
    realRows = testDF.drop(columns='FWC')
    fillerRows = mkSamps(testDF, targetCount - len(testDF))
    testSamps = pd.concat([realRows, fillerRows], axis=0)

    startingGuess = sampler(**genSampParams)
    likelihoodArgs = dict(samps2V=testSamps, wtSerV=wtSer)
    chunks = genMetropolisSamples(nSamp, nIter, startingGuess,
                                  lnLik, likelihoodArgs,
                                  mutator, mutatorParams,
                                  verbose=verbose)

    # Chunks may be DataFrames or arrays; join accordingly and repeat the
    # test set once per chunk so both sides have matching length.
    if isinstance(chunks[0], pd.DataFrame):
        cleanV = pd.concat(chunks)
        expandedTestV = pd.concat([testSamps] * len(chunks))
    else:
        cleanV = np.concatenate(chunks)
        expandedTestV = np.concatenate([testSamps.values] * len(chunks))

    if drawGraph:
        # NOTE(review): this uses the module-level whichBin, not the
        # ``binner`` argument -- confirm that is intended.
        testBins, _ = whichBin(expandedTestV)
        rsltBins = whichBin(cleanV)[0]
        counts = np.histogram2d(testBins, rsltBins, bins=64)[0]
        # +1 keeps empty cells finite under the log.
        plt.imshow(np.log(counts + 1))
        plt.show()

    return mutualInfo(cleanV, expandedTestV, binner, binnerParams=binnerParams)