def fn(mx, ms, sv):
    """
    Scan over stheta and return the best-case strongest constraint value
    for the point (mx, ms, sv).

    A reference value ``sv_1`` is computed for a HiggsPortal with
    gsxx = stheta = 1. For each stheta on a 20-point geometric grid, gsxx is
    set to ``sqrt(sv / sv_1) / stheta`` (i.e. the code assumes the cross
    section scales as (gsxx * stheta)**2), and the minimum over all functions
    returned by ``hp.constraints()`` is recorded.

    Parameters
    ----------
    mx, ms
        Masses forwarded to ``HiggsPortal``; ``ms > mx`` is asserted.
    sv
        Target value, in the same units as ``sv_1``
        (cross section * v_mw * sv_inv_MeV_to_cm3_per_s).

    Returns
    -------
    The maximum over the stheta grid of the per-point minimum constraint
    value, or -1e100 when even stheta = 0.999 would need gsxx > gsxx_max.
    NOTE(review): presumably a positive value means "allowed" -- confirm
    against the convention used by ``HiggsPortal.constraints``.
    """
    # Kept as an assert so the exception type seen by callers is unchanged.
    assert ms > mx
    hp = HiggsPortal(mx, ms, 1, 1)
    sv_1 = hp.annihilation_cross_sections(
        2 * hp.mx * (1 + 0.5 * v_mw**2)
    )["total"] * v_mw * sv_inv_MeV_to_cm3_per_s

    # Product gsxx * stheta required to reproduce sv; it does not depend on
    # the grid point, so compute it once.
    coupling_product = np.sqrt(sv / sv_1)

    # Find smallest stheta compatible with <sigma v> for the given gsxx_max
    stheta_min = coupling_product / gsxx_max
    if stheta_min > 0.999:
        return -1e100

    stheta_grid = np.geomspace(stheta_min, 0.999, 20)
    constr_mins = np.full_like(stheta_grid, np.inf)
    for i, stheta in enumerate(stheta_grid):
        hp.stheta = stheta
        # Read stheta back from the model (as the original did) so any
        # processing done by the attribute is respected.
        hp.gsxx = coupling_product / hp.stheta
        # Figure out strongest constraint
        constr_mins[i] = np.min(
            [constraint() for constraint in hp.constraints().values()]
        )

    # Check if (mx, ms, sv) point is allowed for some (gsxx, stheta) combination
    return constr_mins.max()
def get_hp_relic(mx, ms, stheta, semi_analytic=True):
    """
    Solve for the coupling gsxx at which the Higgs portal model's relic
    density equals ``Omega_dm`` and evaluate the annihilation cross section
    there.

    The root is searched in log10(gsxx): over [-5, log10(4*pi)] when
    ms < mx, otherwise over [-10, 10].

    Parameters
    ----------
    mx, ms, stheta
        Forwarded to ``HiggsPortal``.
    semi_analytic
        Forwarded as the second argument of ``relic_density``.

    Returns
    -------
    (sv, gsxx_relic, b)
        ``gsxx_relic`` is the coupling found by the root search, ``sv`` is
        total cross section * v_mw * sv_inv_MeV_to_cm3_per_s, and ``b`` is
        total cross section / v_mw, both evaluated at ``gsxx_relic``.

    Raises
    ------
    RuntimeError
        If the root finder reports that it did not converge.
    ValueError
        Raised by ``root_scalar`` itself if the bracket does not enclose a
        sign change.
    """
    def fn(log_gsxx):
        hp = HiggsPortal(mx, ms, 10**log_gsxx, stheta)
        # Re-assigned after construction exactly as in the original code;
        # presumably redundant with the constructor argument -- TODO confirm.
        hp.gsxx = 10**log_gsxx
        return relic_density(hp, semi_analytic) - Omega_dm

    if ms < mx:
        bracket = [-5, np.log10(4 * np.pi)]
    else:
        bracket = [-10, 10]

    # xtol is negligibly small, so rtol is the effective stopping criterion.
    sol = root_scalar(fn, bracket=bracket, xtol=1e-100, rtol=1e-4)
    if not sol.converged:
        raise RuntimeError("didn't converge")

    gsxx_relic = 10**sol.root
    hp = HiggsPortal(mx, ms, gsxx_relic, stheta)
    sigma = hp.annihilation_cross_sections(
        2 * hp.mx * (1 + 1/2 * v_mw**2)
    )["total"]
    sv = sigma * v_mw * sv_inv_MeV_to_cm3_per_s
    b = sigma / v_mw

    return sv, gsxx_relic, b
def get_sv_hp_relic(mx, ms, stheta=1):
    """
    Return cross section * v_mw * sv_inv_MeV_to_cm3_per_s at the coupling
    gsxx for which ``relic_density`` equals 0.22.

    The coupling is searched directly (not in log space) over
    (1e-5, 4*pi) when ms < mx, otherwise over (1e-4, 1e8).

    Parameters
    ----------
    mx, ms, stheta
        Forwarded to ``HiggsPortal``; ``stheta`` defaults to 1.

    Returns
    -------
    The "total" entry of ``annihilation_cross_sections`` at the solved
    coupling, multiplied by ``v_mw`` and ``sv_inv_MeV_to_cm3_per_s``.

    Raises
    ------
    RuntimeError
        If the root finder reports that it did not converge.
    ValueError
        Raised by ``root_scalar`` itself if the bracket does not enclose a
        sign change.
    """
    def fn(gsxx):
        hp = HiggsPortal(mx, ms, gsxx, stheta)
        # Target relic density is hard-coded to 0.22 here.
        return relic_density(hp) - 0.22

    if ms < mx:
        bracket = (1e-5, 4*np.pi)
    else:
        bracket = (1e-4, 1e8)

    # root_scalar does not raise on non-convergence for bracketing methods,
    # so check the flag explicitly rather than using an unconverged root.
    sol = root_scalar(fn, bracket=bracket, xtol=1e-100, rtol=1e-3)
    if not sol.converged:
        raise RuntimeError("didn't converge")
    gsxx = sol.root

    hp = HiggsPortal(mx, ms, gsxx, stheta)
    return hp.annihilation_cross_sections(
        2 * hp.mx * (1 + 0.5 * v_mw**2)
    )["total"] * v_mw * sv_inv_MeV_to_cm3_per_s
 def fn(mx, ms):
     """
     Return total cross section * v_mw * sv_inv_MeV_to_cm3_per_s for a
     HiggsPortal built with gsxx = gsxx_max and stheta = 1.
     """
     model = HiggsPortal(mx, ms, gsxx_max, 1)
     # Argument passed to the cross-section routine: 2 * mx * (1 + v_mw**2 / 2)
     # (presumably the centre-of-mass energy at relative velocity v_mw).
     energy_arg = 2 * model.mx * (1 + 0.5 * v_mw**2)
     sigma_total = model.annihilation_cross_sections(energy_arg)["total"]
     return sigma_total * v_mw * sv_inv_MeV_to_cm3_per_s
# Decorate the current pyplot figure: reference lines, title, legend, axis
# limits and labels.

# Magenta vertical line at x = ms.
plt.axvline(ms, color="m")

# plt.axvline(0.5, color="r")
# plt.axvline(105, color="r")
# Dotted black vertical lines at x = 13 and x = 20.
# NOTE(review): the significance of these two values is not evident here.
plt.axvline(13, color="k", linestyle=":")
plt.axvline(20, color="k", linestyle=":")

# NOTE(review): the sin(theta) value in the title is a literal (0.1), not
# read from a variable -- confirm it matches the data being plotted.
plt.title(r"$m_\chi = %g \, \mathrm{MeV}$, $\sin\theta = 0.1$" % mx)
plt.legend(fontsize=8)
plt.ylim(1e-20, 1e-5)
plt.xlabel(r"$T$ [MeV]")
plt.ylabel(r"$\langle \sigma v \rangle_{\bar{\chi}\chi,0} / (g_{S\chi} \sin\theta)^2$ [cm$^3$/s]")

# %% hidden=true
# Scratch check: build a HiggsPortal at one explicit parameter point and
# evaluate its cross sections. The results are not assigned to anything.
hp = HiggsPortal(mx=0.1, ms=0.21000000000000002, gsxx=0.37387522703120396, stheta=0.01)
# Same 2 * mx * (1 + v**2 / 2) argument used elsewhere in this file, with the
# velocity written out as the literal 1e-3.
hp.annihilation_cross_sections(2 * hp.mx * (1 + 1/2 * 1e-3**2))
# NOTE(review): meaning of the argument 2 is not visible here -- confirm
# against HiggsPortal.thermal_cross_section.
hp.thermal_cross_section(2)

# %% hidden=true

# %% hidden=true
# Velocity used by the helpers in this file, both in the
# 2 * mx * (1 + v_mw**2 / 2) argument of annihilation_cross_sections and in
# the sigma * v_mw products.
# NOTE(review): presumably the DM velocity in the Milky Way in units of c --
# confirm.
v_mw = 0.001

def get_hp_relic(mx, ms, stheta, semi_analytic=True):
    """
    Get relic abundance for Higgs portal model where the DM annihilates
    into mediators.
    """
    def fn(log_gsxx):
        hp = HiggsPortal(mx, ms, 10**log_gsxx, stheta)
        hp.gsxx = 10**log_gsxx