def residual(A, zeta, realFlag=False, plotFlag=False):
    """Return the L2 distance between matched subhalo and core log10(m/M) histograms.

    Only subhalos and cores in 1:1 match are considered: the j-th kept
    subhalo is ``sh[...][sh_mask][fmask][j]`` and its core is
    ``cc_filtered[...][idx[:, 0][fmask]][j]``. Both samples of m/M are
    histogrammed in log10 over (-3, 0) with 100 bins via ``hist`` (passing
    ``normScalar=nHalo``, ``normBinsize=True``, ``normLogCnts=True``). The
    Euclidean norm of the difference of the two count arrays is then taken
    over bins whose coordinate lies in [-3, -2].

    Relies on the module-level ``sh``, ``sh_mask``, ``fmask``, ``cc_filtered``,
    ``idx``, ``nHalo``, ``hist`` and ``m_evolved_col``.

    Parameters
    ----------
    A, zeta :
        Passed to ``m_evolved_col`` to choose the evolved core-mass column.
    realFlag : bool
        If True, drop pairs whose core has a truthy ``wasInFragment`` entry.
    plotFlag : bool
        Forwarded to ``hist``.

    Returns
    -------
    float
        ``np.linalg.norm(sh_cnts - cores_cnts)`` restricted to the residual
        window.
    """
    m_col = m_evolved_col(A, zeta)
    # Index into cc_filtered of the core matched to each kept subhalo.
    core_idx = idx[:, 0][fmask]

    # Debug output: number of matched subhalos and matched cores.
    print(len(sh['subhalo_mass'][sh_mask][fmask]))
    print(len(cc_filtered[m_col][core_idx]))

    if realFlag:
        # logical_not instead of the bitwise np.invert: on an integer 0/1
        # column np.invert yields -1/-2, which numpy would silently use as
        # fancy indices instead of a boolean mask.
        realmask = np.logical_not(cc_filtered['wasInFragment'][core_idx])
    else:
        realmask = np.full(np.sum(fmask), True, dtype=bool)

    r = (-3, 0)  # histogram range in log10(m/M)
    r_res = (-3, -2)  # window in which the residual is evaluated

    shmf = (sh['subhalo_mass'][sh_mask][fmask][realmask] /
            sh['M'][sh_mask][fmask][realmask])
    sh_xarr, sh_cnts = hist(np.log10(shmf),
                            bins=100,
                            normed=True,
                            normBinsize=True,
                            normCnts=False,
                            normLogCnts=True,
                            normScalar=nHalo,
                            plotFlag=plotFlag,
                            label='subhalos',
                            alpha=1,
                            range=r)

    cmf = (cc_filtered[m_col][core_idx][realmask] /
           cc_filtered['M'][core_idx][realmask])
    cores_xarr, cores_cnts = hist(np.log10(cmf),
                                  bins=100,
                                  normed=True,
                                  plotFlag=plotFlag,
                                  label='cores',
                                  alpha=1,
                                  range=r,
                                  normScalar=nHalo,
                                  normCnts=False,
                                  normBinsize=True,
                                  normLogCnts=True)
    # Same bins and range were requested, so the bin coordinates must agree
    # for the element-wise difference below to be meaningful.
    assert np.array_equal(cores_xarr, sh_xarr)

    res_mask = (r_res[0] <= sh_xarr) & (sh_xarr <= r_res[1])
    return np.linalg.norm((sh_cnts - cores_cnts)[res_mask])
def shmass(A, zeta, realFlag=False, plotFlag=False, mergedCoreTagFlag=False):
    """Histogram log10 masses of 1:1-matched subhalos and cores.

    The j-th kept subhalo is ``sh['subhalo_mass'][sh_mask][fmask][j]`` and
    its matched core is ``cc_filtered[m_evolved_col(A, zeta)][idx[:, 0][fmask]][j]``.
    Both are histogrammed by ``hist`` in log10(m) over (9, 15) with 100 bins
    (``normCnts=True``, ``normLogCnts=True``, ``normScalar=1``).

    Relies on the module-level ``sh``, ``sh_mask``, ``fmask``, ``cc_filtered``,
    ``idx``, ``hist`` and ``m_evolved_col``.

    Parameters
    ----------
    A, zeta :
        Passed to ``m_evolved_col`` to choose the evolved core-mass column.
    realFlag : bool
        If True, drop pairs whose core has a truthy ``wasInFragment`` entry.
    plotFlag : bool
        Forwarded to ``hist``.
    mergedCoreTagFlag : bool
        If True, additionally keep only pairs whose core has
        ``mergedCoreTag == 0``.

    Returns
    -------
    None
        The histograms are computed by ``hist`` (with ``plotFlag``
        forwarded) but not returned.
    """
    m_col = m_evolved_col(A, zeta)
    # Index into cc_filtered of the core matched to each kept subhalo.
    core_idx = idx[:, 0][fmask]

    realmask = np.full(np.sum(fmask), True, dtype=bool)
    if realFlag:
        # logical_not instead of the bitwise np.invert: on an integer 0/1
        # column np.invert yields -1/-2, which numpy would silently use as
        # fancy indices instead of a boolean mask.
        realmask = np.logical_not(cc_filtered['wasInFragment'][core_idx])
    if mergedCoreTagFlag:
        realmask = realmask & (cc_filtered['mergedCoreTag'][core_idx] == 0)

    r = (9, 15)  # histogram range in log10(m)

    shmf = sh['subhalo_mass'][sh_mask][fmask][realmask]
    hist(np.log10(shmf),
         bins=100,
         normed=True,
         normBinsize=False,
         normCnts=True,
         normLogCnts=True,
         normScalar=1,
         plotFlag=plotFlag,
         label='subhalos',
         alpha=1,
         range=r)

    cmf = cc_filtered[m_col][core_idx][realmask]
    hist(np.log10(cmf),
         bins=100,
         normed=True,
         plotFlag=plotFlag,
         label='cores',
         alpha=1,
         range=r,
         normScalar=1,
         normCnts=True,
         normBinsize=False,
         normLogCnts=True)
def mostMassiveCore(A, zeta, plotFlag):
    """Pair each masked subhalo with the most massive core inside its search radius.

    For every index ``i`` of the module-level ``distance_upper_bound``, the
    cores returned by ``cores_tree.query_ball_point(sh_arr[i],
    r=distance_upper_bound[i])`` are considered. The one with the largest
    ``m_evolved_col(A, zeta)`` mass is paired with subhalo ``i`` of
    ``sh[...][sh_mask]``. Subhalos with no core in range are skipped.

    Side effects:
      * prints the percentage of ``sh_mask`` subhalos that received a core;
      * histograms log10 mass of the matched subhalos and cores with ``hist``
        (range (9, 15), 100 bins, ``normScalar=nHalo``, ``plotFlag``
        forwarded);
      * opens two matplotlib figures histogramming the relative mass
        difference ``(m_sh - m_core) / m_sh`` over (-5, 5). The first covers
        all pairs; the second covers only pairs whose subhalo mass is
        >= 10 * SHMLM.PARTICLES100MASS.

    Returns None.
    """
    m_col = m_evolved_col(A, zeta)
    core_mass = cc_filtered[m_col]

    # Parallel lists: matched subhalo index (into sh[...][sh_mask]) and the
    # index of its most massive core (into cc_filtered).
    iarr, carr = [], []
    for i in tqdm(range(len(distance_upper_bound))):
        qres = cores_tree.query_ball_point(sh_arr[i],
                                           r=distance_upper_bound[i])
        if len(qres) > 0:
            iarr.append(i)
            carr.append(qres[np.argmax(core_mass[qres])])

    # Use float arithmetic: under Python 2, int / numpy-int floors, so the
    # original expression reported 0% for any partial match.
    n_masked = np.sum(sh_mask)
    percentexists = 100.0 * len(iarr) / n_masked if n_masked else 0.0
    print('{}% of masked subhalos have at least 1 core within their search radius.'.format(
        percentexists))

    shmf = sh['subhalo_mass'][sh_mask][iarr]
    cmf = core_mass[carr]
    r = (9, 15)  # histogram range in log10(m)

    hist(np.log10(shmf),
         bins=100,
         normed=True,
         normBinsize=True,
         normCnts=False,
         normLogCnts=True,
         normScalar=nHalo,
         plotFlag=plotFlag,
         label='subhalos',
         alpha=1,
         range=r)
    hist(np.log10(cmf),
         bins=100,
         normed=True,
         plotFlag=plotFlag,
         label='cores',
         alpha=1,
         range=r,
         normScalar=nHalo,
         normCnts=False,
         normBinsize=True,
         normLogCnts=True)

    # Relative mass difference of each subhalo/core pair.
    rel_diff = (shmf - cmf) / shmf

    plt.figure()
    plt.hist(rel_diff, range=(-5, 5), bins=100)
    plt.axvline(x=0, c='k')

    plt.figure()
    shmassmask = shmf >= (SHMLM.PARTICLES100MASS * 10)
    plt.hist(rel_diff[shmassmask], range=(-5, 5), bins=100)
    plt.axvline(x=0, c='k')
def subhalo_core_mass_plots(A,
                            zeta,
                            M1,
                            M2,
                            plotFlag,
                            residualRange,
                            normLogCntsFlag=True,
                            logmOverMFlag=False,
                            phaseMergedFlag=False,
                            normCnts=False,
                            MdepPlots=False,
                            vlinepart=True,
                            normScalar=False,
                            normBinsize=False,
                            shPlot=True,
                            part100cutFlag=False,
                            bins=50,
                            ratio=False,
                            plotHMFlag=True):
    """Compare subhalo and core mass histograms.

    Works in one of two modes.

    ``MdepPlots=True`` (host-mass-binned figure):
        Creates a 2x2 grid of shared-axis panels, one per entry ``lM`` of the
        module-level ``Mlist`` (host mass bin ``[10**lM, 10**(lM + massbinsize)]``).
        In each panel the subhalo histogram (from ``shmfdict``/``shfdict``,
        or their ``_cut100`` variants) is overlaid with core histograms
        from ``cc_filtered_dict_SV`` and, if ``plotHMFlag``,
        ``cc_filtered_dict_HM``. ``M1``, ``M2`` and ``residualRange`` are
        ignored. Returns ``fig, ((ax1, ax2), (ax3, ax4))``.

    ``MdepPlots=False`` (single residual):
        Histograms (100 bins) the subhalos selected by the module-level
        ``sh_mask`` and ``sh100mask`` against the cores in ``cc_filtered``
        whose evolved mass is >= ``SHMLM.PARTICLECUTMASS``. Returns the L2
        norm of ``(sh_cnts - cores_cnts) / sh_cnts`` over bins whose
        coordinate lies in ``[residualRange[0], residualRange[1]]``.
        ``M1`` and ``M2`` are unused; the selection comes from the globals.

    Parameters
    ----------
    A, zeta :
        Passed to ``m_evolved_col`` to choose the evolved core-mass column.
    M1, M2 :
        Host-mass bounds; unused by the current implementation (see above).
    plotFlag : bool
        Forwarded to ``hist``. In MdepPlots mode with ``logmOverMFlag`` and
        ``ratio`` it is forced to False so only the ratio curves are drawn.
    residualRange : (float, float)
        Inclusive bin-coordinate window for the residual (non-MdepPlots).
    normLogCntsFlag, normCnts :
        Forwarded to ``hist`` as ``normLogCnts`` / ``normCnts``.
    logmOverMFlag : bool
        Histogram log10(m/M) over (-3, 0) instead of log10(m).
    phaseMergedFlag : bool
        Non-MdepPlots only: also drop cores with ``phaseSpaceMerged == 1``.
    MdepPlots : bool
        Selects the mode described above.
    vlinepart : bool
        MdepPlots log10(m) mode only: draw vertical lines at the SV and HM
        ``PARTICLES100MASS``.
    normScalar : bool
        Pass halo counts (``nHalo``, or the per-bin core halo counts) as
        ``hist``'s ``normScalar`` instead of 1.
    normBinsize : bool
        Forwarded to ``hist``. Ignored in non-MdepPlots log10(m) mode, which
        always passes False.
    shPlot : bool
        MdepPlots only: whether to histogram the subhalos.
    part100cutFlag : bool
        MdepPlots only: apply the ``PARTICLECUTMASS`` cut to cores and use
        the ``_cut100`` subhalo dictionaries.
    bins : int
        Number of bins in MdepPlots mode (non-MdepPlots mode uses 100).
    ratio : bool
        MdepPlots with ``logmOverMFlag`` only: plot ``10**(cores - subhalos)``
        of the histogram counts instead of the histograms themselves.
    plotHMFlag : bool
        MdepPlots only: include the HM core histogram.

    Raises
    ------
    ValueError
        In MdepPlots mode, if ``ratio`` and ``logmOverMFlag`` are set but
        ``shPlot`` is not, since the ratio needs the subhalo histogram.
    """
    m_col = m_evolved_col(A, zeta)

    if MdepPlots:
        if logmOverMFlag and ratio and not shPlot:
            raise ValueError('ratio=True requires shPlot=True')

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
            2,
            2,
            sharex=True,
            sharey=True,
            gridspec_kw={
                'hspace': 0,
                'wspace': 0
            },
            figsize=[4.8 * 2, 4.8 * 1.5],
            dpi=150)

        if logmOverMFlag and ratio:
            # Only the ratio curves are drawn; keep hist() from plotting.
            plotFlag = False

        for ax, lM in zip([ax1, ax2, ax3, ax4], Mlist):
            M1, M2 = 10**lM, 10**(lM + massbinsize)

            nH_core_SV, cc_filtered_SV = cc_filtered_dict_SV[lM]
            nH_core_HM, cc_filtered_HM = cc_filtered_dict_HM[lM]
            m_SV = cc_filtered_SV[m_col]
            m_HM = cc_filtered_HM[m_col]

            if part100cutFlag:
                cores100mask_SV = m_SV >= SHMLM_SV.PARTICLECUTMASS
                cores100mask_HM = m_HM >= SHMLM_HM.PARTICLECUTMASS
            else:
                cores100mask_SV = np.full_like(m_SV >= 0, True)
                cores100mask_HM = np.full_like(m_HM >= 0, True)

            if logmOverMFlag:
                r = (-3, 0)
                sh_dict = shmfdict_cut100 if part100cutFlag else shmfdict
                sv_alpha = .7
            else:
                r = (np.log10(SHMLM_HM.PARTICLECUTMASS / 100.), 15)
                sh_dict = shfdict_cut100 if part100cutFlag else shfdict
                sv_alpha = 1
            # The second element of the entry (presumably a halo count) is
            # not used; subhalos are normalised like the SV cores below.
            shmf, _ = sh_dict[lM]

            nH_sh_param, nH_cHM_param, nH_cSV_param = 1, 1, 1
            if normScalar:
                nH_sh_param = nH_core_SV
                nH_cHM_param = nH_core_HM
                nH_cSV_param = nH_core_SV
            print('nH_sh_param:  %s' % (nH_sh_param,))
            print('nH_cSV_param:  %s' % (nH_cSV_param,))
            print('nH_cHM_param:  %s' % (nH_cHM_param,))
            print('')

            if shPlot:
                sh_xarr, sh_cnts = hist(np.log10(shmf),
                                        bins=bins,
                                        normed=True,
                                        normBinsize=normBinsize,
                                        normCnts=normCnts,
                                        normLogCnts=normLogCntsFlag,
                                        normScalar=nH_sh_param,
                                        plotFlag=plotFlag,
                                        label='subhalos SV',
                                        alpha=sv_alpha,
                                        range=r,
                                        ax=ax)

            cmf_SV = m_SV[cores100mask_SV]
            if logmOverMFlag:
                cmf_SV = cmf_SV / cc_filtered_SV['M'][cores100mask_SV]
            # SV cores are normalised by the SV count in both mass modes
            # (the log10(m) branch previously used the HM count by mistake).
            cores_xarr_SV, cores_cnts_SV = hist(np.log10(cmf_SV),
                                                bins=bins,
                                                normed=True,
                                                plotFlag=plotFlag,
                                                label='cores SV',
                                                alpha=sv_alpha,
                                                range=r,
                                                normScalar=nH_cSV_param,
                                                normCnts=normCnts,
                                                normBinsize=normBinsize,
                                                normLogCnts=normLogCntsFlag,
                                                ax=ax)

            if plotHMFlag:
                cmf_HM = m_HM[cores100mask_HM]
                if logmOverMFlag:
                    cmf_HM = cmf_HM / cc_filtered_HM['M'][cores100mask_HM]
                _, cores_cnts_HM = hist(np.log10(cmf_HM),
                                        bins=bins,
                                        normed=True,
                                        plotFlag=plotFlag,
                                        label='cores HM',
                                        alpha=1,
                                        range=r,
                                        normScalar=nH_cHM_param,
                                        normCnts=normCnts,
                                        normBinsize=normBinsize,
                                        normLogCnts=normLogCntsFlag,
                                        ax=ax)

            if logmOverMFlag:
                if ratio:
                    ax.axhline(1, ls='--', alpha=1, zorder=0, c='k')
                    ratio_SV = 10**(cores_cnts_SV - sh_cnts)
                    isfin_SV = np.isfinite(ratio_SV)
                    ax.plot(sh_xarr[isfin_SV],
                            ratio_SV[isfin_SV],
                            label='cores/subhalos (SV)')
                    if plotHMFlag:
                        ratio_HM = 10**(cores_cnts_HM - sh_cnts)
                        isfin_HM = np.isfinite(ratio_HM)
                        ax.plot(sh_xarr[isfin_HM],
                                ratio_HM[isfin_HM],
                                label='cores/subhalos (HM)')
                    ax.set_ylim(0, 3.4)
                ax.set_xlim(-3.2, 0)
            else:
                ax.set_ylim(-5, 3.6)
                ax.set_xlim(7.5, 14.2)
                if vlinepart:
                    ax.axvline(x=np.log10(SHMLM_SV.PARTICLES100MASS),
                               c='r',
                               lw=0.5,
                               label='100 part. SV',
                               ymax=0.90)
                    ax.axvline(x=np.log10(SHMLM_HM.PARTICLES100MASS),
                               c='k',
                               lw=0.5,
                               label='100 part. HM')

            if shPlot:
                # Subhalo and SV-core histograms use the same bins and range.
                assert np.array_equal(cores_xarr_SV, sh_xarr)
            ax.set_title(r'{} $\leq$ log(M/$h^{{-1}}M_\odot$)$\leq$ {}'.format(
                np.log10(M1), np.log10(M2)),
                         pad=-16,
                         x=0.60)

        if logmOverMFlag:
            fig.text(0.5, 0.08, r'$\log(m/M)$', ha="center", va="center")
            if ratio:
                ylabel = r'ratio $\left[ \mathrm{d}N/\mathrm{d} \log(m/M) \right]$'
            else:
                ylabel = r'$\log \left[ \mathrm{d}N/\mathrm{d} \log(m/M) \right]$'
        else:
            fig.text(0.5, 0.08, r'$\log(m)$', ha="center", va="center")
            ylabel = r'$\log \left[ \mathrm{d}N/\mathrm{d} \log(m) \right]$'
        fig.text(0.08, 0.5, ylabel, ha="center", va="center", rotation=90)
        ax4.legend(loc=4)
        return fig, ((ax1, ax2), (ax3, ax4))

    # Single-residual mode. The module-level sh, sh_mask, sh100mask,
    # cc_filtered and nHalo are only read, so no `global` statement is needed.
    cores100mask = cc_filtered[m_col] >= SHMLM.PARTICLECUTMASS
    if phaseMergedFlag:
        cores100mask = cores100mask & (cc_filtered['phaseSpaceMerged'] != 1)

    sh_m = sh['subhalo_mass'][sh_mask][sh100mask]
    core_m = cc_filtered[m_col][cores100mask]
    if logmOverMFlag:
        r = (-3, 0)
        normScalarVal = nHalo if normScalar else 1
        binNorm = normBinsize
        shmf = sh_m / sh['M'][sh_mask][sh100mask]
        cmf = core_m / cc_filtered['M'][cores100mask]
    else:
        r = (9, 15)
        # log10(m) mode always passes normScalar=1 and normBinsize=False.
        normScalarVal = 1
        binNorm = False
        shmf = sh_m
        cmf = core_m

    sh_xarr, sh_cnts = hist(np.log10(shmf),
                            bins=100,
                            normed=True,
                            normBinsize=binNorm,
                            normCnts=normCnts,
                            normLogCnts=normLogCntsFlag,
                            normScalar=normScalarVal,
                            plotFlag=plotFlag,
                            label='subhalos',
                            alpha=1,
                            range=r)
    cores_xarr, cores_cnts = hist(np.log10(cmf),
                                  bins=100,
                                  normed=True,
                                  plotFlag=plotFlag,
                                  label='cores',
                                  alpha=1,
                                  range=r,
                                  normScalar=normScalarVal,
                                  normCnts=normCnts,
                                  normBinsize=binNorm,
                                  normLogCnts=normLogCntsFlag)
    assert np.array_equal(cores_xarr, sh_xarr)

    res_mask = (residualRange[0] <= sh_xarr) & (sh_xarr <= residualRange[1])
    return np.linalg.norm(
        np.true_divide((sh_cnts - cores_cnts)[res_mask], sh_cnts[res_mask]))