def fn(mx, ms, sv):
        """Return the best achievable worst-case constraint value at (mx, ms, sv).

        A ``HiggsPortal`` model is built with both couplings set to 1 and its
        total annihilation cross section, multiplied by ``v_mw`` and
        ``sv_inv_MeV_to_cm3_per_s``, is used as the reference value ``sv_1``.
        The coupling product needed to reach ``sv`` is then taken to be
        ``sqrt(sv / sv_1)``, and ``stheta`` is scanned over a log-spaced grid
        with ``gsxx`` adjusted so the product stays fixed.

        Parameters
        ----------
        mx, ms : float
            Passed as the first two positional arguments of ``HiggsPortal``
            (presumably the DM and mediator masses -- confirm against the
            class). ``ms`` must be larger than ``mx``.
        sv : float
            Target value of <sigma v>, in the units produced by
            ``sv_inv_MeV_to_cm3_per_s`` (presumably cm^3/s -- confirm).

        Returns
        -------
        float
            ``-1e100`` if the smallest ``stheta`` compatible with the global
            ``gsxx_max`` exceeds 0.999. Otherwise the maximum over the
            ``stheta`` grid of the minimum constraint value at each grid
            point. NOTE(review): elsewhere in this file constraint values
            > 0 are labelled "allowed"; presumably the same convention
            applies here.

        Notes
        -----
        Reads the module-level names ``v_mw``, ``sv_inv_MeV_to_cm3_per_s``
        and ``gsxx_max``.
        """
        assert ms > mx
        hp = HiggsPortal(mx, ms, 1, 1)
        # Reference <sigma v> for unit couplings. The argument is presumably
        # the center-of-mass energy at relative velocity v_mw -- confirm.
        sv_1 = (
            hp.annihilation_cross_sections(2 * hp.mx * (1 + 0.5 * v_mw**2))["total"]
            * v_mw
            * sv_inv_MeV_to_cm3_per_s
        )

        # Value that gsxx * stheta must take to reproduce sv; it does not
        # depend on the grid point, so compute it once.
        coupling_product = np.sqrt(sv / sv_1)

        # Find smallest stheta compatible with <sigma v> for the given
        # gsxx_max
        stheta_min = coupling_product / gsxx_max
        if stheta_min > 0.999:
            return -1e100

        stheta_grid = np.geomspace(stheta_min, 0.999, 20)
        constr_mins = np.full_like(stheta_grid, np.inf)
        for i, stheta in enumerate(stheta_grid):
            hp.stheta = stheta
            # Read stheta back from the model (as the original code did) so
            # any normalisation done by the attribute is respected.
            hp.gsxx = coupling_product / hp.stheta
            # Figure out strongest constraint
            constr_mins[i] = np.min(
                [check() for check in hp.constraints().values()]
            )

        # Check if (mx, ms, sv) point is allowed for some (gsxx, stheta) combination
        return constr_mins.max()

# %% hidden=true
# Model instance reused for every grid point; mx, ms and stheta are
# overwritten inside the scan below.
hp = HiggsPortal(mx=100, ms=1, gsxx=1, stheta=1.)
ms_over_mx = 1.1
if ms_over_mx < 1:  # xx->ss
    mx_grid = np.linspace(0.1, 1000, 60)
else:  # xx->SM
    mx_grid = np.linspace(0.1, 500, 60)
# ms is tied to mx through the fixed ratio, so both grids share one index.
ms_grid = ms_over_mx * mx_grid
stheta_grid = np.geomspace(1e-6, 0.999, 50)
# constrs = hp.constrain("ms", ms_grid, "stheta", stheta_grid, "image")

# One array per constraint, indexed [stheta index, mx index].
# > 0: allowed, < 0: not allowed.
constrs = {
    c: np.full((len(stheta_grid), len(mx_grid)), np.inf) for c in hp.constraints()
}
for i, stheta_val in enumerate(stheta_grid):  # stheta
    for j, (mx_val, ms_val) in enumerate(zip(mx_grid, ms_grid)):  # mx, ms
        hp.mx = mx_val
        hp.ms = ms_val
        hp.stheta = stheta_val
        # The loop target is deliberately not called `fn`: at module level
        # that would rebind the `fn` function defined earlier in this file.
        for c, constraint_fn in hp.constraints().items():
            constrs[c][i, j] = constraint_fn()

# %% code_folding=[] hidden=true
for (name, constr), color in zip(constrs.items(), mpl_colors):
#     if name == "higgs -> invis":
#         continue
    plt.contourf(
        mx_grid, stheta_grid, constr, levels=[-1e100, 0], colors=[color],