Exemple #1
0
    def likelihood(cube, ndim, nparams):
        """
        Evaluate the log-likelihood for one PSPL + parallax parameter vector.

        Parameters
        ----------
        cube : indexable, at least 20 entries
            Entries 0-11 are read as the free parameters, in this order:
            mL, t0, xS0[0], xS0[1], beta, muL[0], muL[1], muS[0], muS[1],
            dL, dL/dS, imag_base.
            Entries 12-19 are overwritten in place with derived quantities:
            dS, tE, thetaE_amp, piE[0], piE[1], u0_amp, muRel[0], muRel[1].
        ndim, nparams : unused
            Not read here; presumably required by the sampler's callback
            signature -- TODO confirm against the caller.

        Returns
        -------
        lnL : the sum of the *mean* photometric and the *mean* astrometric
            log-likelihood terms returned by the model.

        Notes
        -----
        ``raL``, ``decL``, ``t_phot``, ``imag``, ``imag_err``, ``t_ast``,
        ``xpos``, ``ypos``, ``xpos_err``, ``ypos_err`` and ``model`` are not
        defined here; they are resolved from the enclosing scope.
        """
        # Free parameters.
        mL = cube[0]
        t0 = cube[1]
        xS0 = np.array([cube[2], cube[3]])
        beta = cube[4]
        muL = np.array([cube[5], cube[6]])
        muS = np.array([cube[7], cube[8]])
        dL = cube[9]
        dLdS = cube[10]
        imag_base = cube[11]

        # Derived source distance; the 1.0 factor forces float division
        # even if the interpreter uses integer division for ints.
        dS = (1.0 * dL) / dLdS
        cube[12] = dS

        pspl = model.PSPL_parallax(raL, decL, mL, t0, xS0, beta,
                                   muL, muS, dL, dS, imag_base)

        # Store derived model quantities back into the cube.
        cube[13] = pspl.tE
        cube[14] = pspl.thetaE_amp
        cube[15] = pspl.piE[0]
        cube[16] = pspl.piE[1]
        cube[17] = pspl.u0_amp
        cube[18] = pspl.muRel[0]
        cube[19] = pspl.muRel[1]

        lnL_phot = pspl.likely_photometry(t_phot, imag, imag_err)
        lnL_ast = pspl.likely_astrometry(t_ast, xpos, ypos, xpos_err, ypos_err)

        # Each data set contributes its mean term, not its sum.
        return lnL_phot.mean() + lnL_ast.mean()
def run_pspl_parallax_fit():
    """
    Plot and summarize an existing PSPL + parallax MultiNest fit.

    Loads the data with ``load_data()``, reads the best-fit parameters of
    run code 'aa' from 'mnest_pspl_par/', rebuilds the model, prints the
    log-likelihood and chi^2 statistics, and writes four figures to
    'mnest_pspl_par/plots/aa_*.png' (photometry, on-sky astrometry, and
    the two astrometric components versus time).

    The fit itself is NOT run here; it is expected to have been produced
    beforehand with::

        model_fitter.multinest_pspl_parallax(data, n_live_points=300,
                                             saveto='./mnest_pspl_par/',
                                             runcode='aa')

    Returns
    -------
    None
    """
    fit_dir = 'mnest_pspl_par/'
    runcode = 'aa'
    n_fit_params = 12  # number of fitted parameters removed from each dof

    data = load_data()

    model_fitter.plot_posteriors(fit_dir, runcode)
    best = model_fitter.get_best_fit(fit_dir, runcode)

    pspl_out = model.PSPL_parallax(data['raL'], data['decL'], best['mL'],
                                   best['t0'],
                                   np.array([best['xS0_E'], best['xS0_N']]),
                                   best['beta'],
                                   np.array([best['muL_E'], best['muL_N']]),
                                   np.array([best['muS_E'], best['muS_N']]),
                                   best['dL'], best['dS'], best['imag_base'])

    # Model on a dense time grid (for curves) and at the data epochs
    # (for residuals and chi^2).
    t_out = np.arange(55000, 58000, 1)
    imag_out = pspl_out.get_photometry(t_out)
    pos_out = pspl_out.get_astrometry(t_out)

    imag_out_data = pspl_out.get_photometry(data['t_phot'])
    pos_out_data = pspl_out.get_astrometry(data['t_ast'])

    lnL_phot_out = pspl_out.likely_photometry(data['t_phot'], data['imag'],
                                              data['imag_err'])
    lnL_ast_out = pspl_out.likely_astrometry(data['t_ast'], data['xpos'],
                                             data['ypos'], data['xpos_err'],
                                             data['ypos_err'])
    lnL_out = lnL_phot_out.mean() + lnL_ast_out.mean()

    chi2_phot = (((data['imag'] - imag_out_data) / data['imag_err'])**2).sum()
    chi2_ast = (((data['xpos'] - pos_out_data[:, 0]) /
                 data['xpos_err'])**2).sum()
    chi2_ast += (((data['ypos'] - pos_out_data[:, 1]) /
                  data['ypos_err'])**2).sum()
    chi2_tot = chi2_phot + chi2_ast

    n_phot = len(data['imag'])
    n_ast = 2 * len(data['xpos'])  # two coordinates per astrometric epoch
    dof_phot = n_phot - n_fit_params
    dof_ast = n_ast - n_fit_params
    dof_tot = (n_phot + n_ast) - n_fit_params

    # print() calls (not print statements) so the function also parses
    # under Python 3, matching the other functions in this file.
    print('lnL for output: ', lnL_out)
    print('Photometry chi^2 = {0:4.1f} (dof={1:4d})'.format(
        chi2_phot, dof_phot))
    print('Astrometry chi^2 = {0:4.1f} (dof={1:4d})'.format(chi2_ast, dof_ast))
    print('Total      chi^2 = {0:4.1f} (dof={1:4d})'.format(chi2_tot, dof_tot))

    outroot = fit_dir + 'plots/' + runcode

    ##########
    # Photometry vs. time
    ##########
    fig = plt.figure(1)
    plt.clf()
    top = fig.add_axes((0.2, 0.3, 0.75, 0.6))
    plt.errorbar(data['t_phot'], data['imag'], yerr=data['imag_err'], fmt='r.')
    plt.plot(t_out, imag_out, 'k-')
    plt.ylabel('I (mag)', fontsize=12)
    plt.title('Input Data and Output Model', fontsize=12)
    plt.xlim(data['t_phot'].min(), data['t_phot'].max())
    plt.gca().invert_yaxis()  # magnitudes: brighter is up
    top.set_xticklabels([])

    # Residual panel (model - data) under the light curve.
    fig.add_axes((0.2, 0.15, 0.75, 0.15))
    plt.errorbar(data['t_phot'],
                 imag_out_data - data['imag'],
                 yerr=data['imag_err'],
                 fmt='r.')
    plt.axhline(0, linestyle='-', color='k')
    plt.xlim(data['t_phot'].min(), data['t_phot'].max())
    plt.ylim(-0.3, 0.3)
    plt.xlabel('t - t0 (days)', fontsize=12)
    plt.ylabel('Residual', fontsize=12)
    plt.gca().invert_yaxis()
    plt.savefig(outroot + '_phot.png')

    ##########
    # Astrometry 2D
    ##########
    plt.figure(2)
    plt.clf()

    lens_pos = pspl_out.get_lens_astrometry(t_out)
    srce_pos_unlens = pspl_out.get_astrometry_unlensed(t_out)

    # All positions are plotted relative to the lens position at the grid
    # epoch closest to t0, scaled by 1e3 (labelled as mas).
    t0idx = np.argmin(np.abs(t_out - pspl_out.t0))
    x0 = lens_pos[t0idx, 0]
    y0 = lens_pos[t0idx, 1]

    plt.errorbar((data['xpos'] - x0) * 1e3, (data['ypos'] - y0) * 1e3,
                 xerr=data['xpos_err'] * 1e3,
                 yerr=data['ypos_err'] * 1e3,
                 fmt='r.')
    plt.plot((pos_out[:, 0] - x0) * 1e3, (pos_out[:, 1] - y0) * 1e3,
             'k-',
             label='Source (lensed)')
    plt.plot((srce_pos_unlens[:, 0] - x0) * 1e3,
             (srce_pos_unlens[:, 1] - y0) * 1e3,
             'k--',
             label='Source (unlensed)')
    plt.plot((lens_pos[:, 0] - x0) * 1e3, (lens_pos[:, 1] - y0) * 1e3,
             'b--',
             label='Lens')
    plt.legend(fontsize=12)
    plt.gca().invert_xaxis()
    plt.xlabel(r'$\Delta$RA (mas)', fontsize=12)
    plt.ylabel(r'$\Delta$Dec (mas)', fontsize=12)
    plt.xlim(5, -10)
    plt.ylim(-10, 5)
    plt.title('Input Data and Output Model', fontsize=12)
    plt.savefig(outroot + '_ast.png')

    ##########
    # Astrometry East (figure 3) and North (figure 4) vs. time
    ##########
    components = (
        (3, 0, 'xpos', 'xpos_err', x0, r'$\Delta$RA (mas)', '_t_vs_E.png'),
        (4, 1, 'ypos', 'ypos_err', y0, r'$\Delta$Dec (mas)', '_t_vs_N.png'),
    )
    for fignum, col, pos_key, err_key, ref, ylabel, suffix in components:
        obs = data[pos_key]
        err_mas = data[err_key] * 1e3

        fig = plt.figure(fignum)
        plt.clf()
        top = fig.add_axes((0.2, 0.3, 0.75, 0.6))
        plt.errorbar(data['t_ast'], (obs - ref) * 1e3, yerr=err_mas, fmt='r.')
        plt.plot(t_out, (pos_out[:, col] - ref) * 1e3, 'k-')
        plt.xlim(56000, 57600)
        plt.ylabel(ylabel, fontsize=12)
        plt.title('Input Data and Output Model', fontsize=12)
        top.set_xticklabels([])

        # Residual panel (model - data).
        fig.add_axes((0.2, 0.15, 0.75, 0.15))
        plt.errorbar(data['t_ast'], (pos_out_data[:, col] - obs) * 1e3,
                     yerr=err_mas,
                     fmt='r.')
        plt.axhline(0, linestyle='--', color='k')
        plt.xlim(56000, 57600)
        plt.ylim(-1.5, 1.5)
        plt.xlabel('t - t0 (days)', fontsize=12)
        plt.ylabel('Residuals (mas)', fontsize=12)
        plt.savefig(outroot + suffix)

    return
Exemple #3
0
def fake_data_parallax(raL_in,
                       decL_in,
                       mL_in,
                       t0_in,
                       xs0_in,
                       beta_in,
                       muS_in,
                       muL_in,
                       dL_in,
                       dS_in,
                       imag_in,
                       outdir=''):
    """
    Simulate noisy photometry and astrometry for a PSPL + parallax event.

    A ``model.PSPL_parallax`` is built from the input parameters, sampled
    on a synthetic observing cadence, perturbed with Gaussian noise from
    ``np.random.randn`` (unseeded), and plotted to four PNG files written
    to ``outdir + 'fake_data_*.png'``.

    Parameters
    ----------
    raL_in, decL_in, mL_in, t0_in, beta_in, dL_in, dS_in, imag_in :
        Passed straight through to ``model.PSPL_parallax``.
    xs0_in : indexable of length 2
        Source position passed as the model's xS0 argument.
    muS_in, muL_in : indexable of length 2
        Source and lens proper motions.  Note the signature takes muS
        before muL, while the model constructor takes muL before muS.
    outdir : str
        Prefix (including any trailing separator) for the output figures.

    Returns
    -------
    data : dict
        Keys 't_phot', 'imag', 'imag_err', 't_ast', 'xpos', 'ypos',
        'xpos_err', 'ypos_err', 'raL', 'decL'.
    params : dict
        The input parameters keyed as 'raL', 'decL', 'mL', 't0', 'xS0_E',
        'xS0_N', 'beta', 'muS_E', 'muS_N', 'muL_E', 'muL_N', 'dL', 'dS',
        'imag_base'.
    """
    # The parameter is spelled ``xs0_in``; the previous body referred to
    # ``xS0_in`` and raised NameError.
    pspl_in = model.PSPL_parallax(raL_in, decL_in, mL_in, t0_in, xs0_in,
                                  beta_in, muL_in, muS_in, dL_in, dS_in,
                                  imag_in)

    # Simulate
    # photometric observations every 1 day and
    # astrometric observations every 14 days
    # for the bulge observing window. Observations missed
    # for 125 days out of 365 days for photometry and missed
    # for 245 days out of 365 days for astrometry.
    year_len = 365.25
    phot_win = 240.0
    ast_win = 120.0
    # Center each observing window within the year.
    phot_start = (year_len - phot_win) / 2.0
    ast_start = (year_len - ast_win) / 2.0

    t_phot = np.array([], dtype=float)
    t_ast = np.array([], dtype=float)
    for year_start in np.arange(56000, 58000, year_len):
        t_phot_new = np.arange(year_start + phot_start,
                               year_start + phot_start + phot_win, 1)
        t_phot = np.concatenate([t_phot, t_phot_new])

        t_ast_new = np.arange(year_start + ast_start,
                              year_start + ast_start + ast_win, 14)
        t_ast = np.concatenate([t_ast, t_ast_new])

    # Make the photometric observations.
    # Photon (sqrt(N)) noise with a zero point of flux0 = 4000 e- at I=19,
    # i.e. 1.087 / sqrt(4000) ~ 0.017 mag error at I=19.
    # NOTE(review): the original comment said "0.05 mag ... 400 e-", which
    # does not match flux0 = 4000 -- confirm which was intended.
    flux0 = 4000.0
    imag0 = 19.0
    imag_obs = pspl_in.get_photometry(t_phot)
    flux_obs = flux0 * 10**((imag_obs - imag0) / -2.5)
    flux_obs_err = flux_obs**0.5
    flux_obs += np.random.randn(len(t_phot)) * flux_obs_err
    imag_obs = -2.5 * np.log10(flux_obs / flux0) + imag0
    # 1.087 / SNR, with SNR = flux / sqrt(flux) = sqrt(flux) from the
    # noise-free flux.
    imag_obs_err = 1.087 / flux_obs_err

    # Make the astrometric observations.
    # Constant error of 0.01e-3 position units per axis at every epoch
    # (0.01 mas if positions are in arcsec, as the plot labels indicate).
    # NOTE(review): the original comment said 0.15 mas -- confirm.
    pos_obs_tmp = pspl_in.get_astrometry(t_ast)
    pos_obs_err = np.ones((len(t_ast), 2), dtype=float) * 0.01 * 1e-3
    pos_obs = pos_obs_tmp + pos_obs_err * np.random.randn(len(t_ast), 2)

    plt.figure(1)
    plt.clf()
    plt.errorbar(t_phot, imag_obs, yerr=imag_obs_err, fmt='k.')
    plt.xlabel('t - t0 (days)')
    plt.ylabel('I (mag)')
    plt.title('Input Data and Model')
    plt.savefig(outdir + 'fake_data_phot.png')

    plt.figure(2)
    plt.clf()
    plt.errorbar(pos_obs[:, 0],
                 pos_obs[:, 1],
                 xerr=pos_obs_err[:, 0],
                 yerr=pos_obs_err[:, 1],
                 fmt='k.')
    plt.gca().invert_xaxis()
    plt.xlabel('X Pos (")')
    plt.ylabel('Y Pos (")')
    plt.plot(pos_obs_tmp[:, 0], pos_obs_tmp[:, 1], 'r--')
    plt.title('Input Data and Model')
    plt.savefig(outdir + 'fake_data_ast.png')

    plt.figure(3)
    plt.clf()
    plt.errorbar(t_ast, pos_obs[:, 0], yerr=pos_obs_err[:, 0], fmt='k.')
    plt.plot(t_ast, pos_obs_tmp[:, 0], 'r--')
    plt.xlabel('t - t0 (days)')
    plt.ylabel('X Pos (")')
    plt.title('Input Data and Model')
    plt.savefig(outdir + 'fake_data_t_vs_E.png')

    plt.figure(4)
    plt.clf()
    plt.errorbar(t_ast, pos_obs[:, 1], yerr=pos_obs_err[:, 1], fmt='k.')
    plt.plot(t_ast, pos_obs_tmp[:, 1], 'r--')
    plt.xlabel('t - t0 (days)')
    plt.ylabel('Y Pos (")')
    plt.title('Input Data and Model')
    plt.savefig(outdir + 'fake_data_t_vs_N.png')

    data = {
        't_phot': t_phot,
        'imag': imag_obs,
        'imag_err': imag_obs_err,
        't_ast': t_ast,
        'xpos': pos_obs[:, 0],
        'ypos': pos_obs[:, 1],
        'xpos_err': pos_obs_err[:, 0],
        'ypos_err': pos_obs_err[:, 1],
        'raL': raL_in,
        'decL': decL_in,
    }

    params = {
        'raL': raL_in,
        'decL': decL_in,
        'mL': mL_in,
        't0': t0_in,
        'xS0_E': xs0_in[0],
        'xS0_N': xs0_in[1],
        'beta': beta_in,
        'muS_E': muS_in[0],
        'muS_N': muS_in[1],
        'muL_E': muL_in[0],
        'muL_N': muL_in[1],
        'dL': dL_in,
        'dS': dS_in,
        'imag_base': imag_in,
    }

    return data, params
Exemple #4
0
def test_pspl_parallax_fit():
    """
    Fit simulated PSPL + parallax data and compare the result to the input.

    Generates fake data with ``fake_data_parallax_lmc()``, runs the
    MultiNest fit, rebuilds both the best-fit and the input models, prints
    their log-likelihoods against the simulated data, and plots the
    posteriors (with the input values overlaid) and the data/model
    comparison via ``plot_mnest_test``.

    All fit products are written to and read from a single directory.
    Previously the fit was saved to './mnest_pspl_par_lmc/' but the best
    fit, posteriors and plots were taken from 'mnest_pspl_par/', i.e.
    from a different run.

    Returns
    -------
    None
    """
    fit_dir = './mnest_pspl_par_lmc/'
    runcode = 'aa'

    data, p_in = fake_data_parallax_lmc()

    model_fitter.multinest_pspl_parallax(data,
                                         n_live_points=300,
                                         saveto=fit_dir,
                                         runcode=runcode)

    best = model_fitter.get_best_fit(fit_dir, runcode)

    pspl_out = model.PSPL_parallax(p_in['raL'], p_in['decL'], best['mL'],
                                   best['t0'],
                                   np.array([best['xS0_E'], best['xS0_N']]),
                                   best['beta'],
                                   np.array([best['muL_E'], best['muL_N']]),
                                   np.array([best['muS_E'], best['muS_N']]),
                                   best['dL'], best['dS'], best['imag_base'])

    pspl_in = model.PSPL_parallax(p_in['raL'], p_in['decL'], p_in['mL'],
                                  p_in['t0'],
                                  np.array([p_in['xS0_E'], p_in['xS0_N']]),
                                  p_in['beta'],
                                  np.array([p_in['muL_E'], p_in['muL_N']]),
                                  np.array([p_in['muS_E'], p_in['muS_N']]),
                                  p_in['dL'], p_in['dS'], p_in['imag_base'])

    # Add the derived quantities of the input model so they can be shown
    # alongside the fitted ones on the posterior plots.
    p_in['tE'] = pspl_in.tE
    p_in['thetaE'] = pspl_in.thetaE_amp
    p_in['piE_E'] = pspl_in.piE[0]
    p_in['piE_N'] = pspl_in.piE[1]
    p_in['u0_amp'] = pspl_in.u0_amp
    p_in['muRel_E'] = pspl_in.muRel[0]
    p_in['muRel_N'] = pspl_in.muRel[1]

    model_fitter.plot_posteriors(fit_dir, runcode, sim_vals=p_in)

    imag_out = pspl_out.get_photometry(data['t_phot'])
    pos_out = pspl_out.get_astrometry(data['t_ast'])

    imag_in = pspl_in.get_photometry(data['t_phot'])
    pos_in = pspl_in.get_astrometry(data['t_ast'])

    lnL_phot_out = pspl_out.likely_photometry(data['t_phot'], data['imag'],
                                              data['imag_err'])
    lnL_ast_out = pspl_out.likely_astrometry(data['t_ast'], data['xpos'],
                                             data['ypos'], data['xpos_err'],
                                             data['ypos_err'])
    lnL_out = lnL_phot_out.mean() + lnL_ast_out.mean()

    lnL_phot_in = pspl_in.likely_photometry(data['t_phot'], data['imag'],
                                            data['imag_err'])
    lnL_ast_in = pspl_in.likely_astrometry(data['t_ast'], data['xpos'],
                                           data['ypos'], data['xpos_err'],
                                           data['ypos_err'])
    lnL_in = lnL_phot_in.mean() + lnL_ast_in.mean()

    print('lnL for input: ', lnL_in)
    print('lnL for output: ', lnL_out)

    # NOTE(review): assumes the 'plots/' sub-directory exists under fit_dir.
    outroot = fit_dir + 'plots/' + runcode

    plot_mnest_test(data, imag_in, imag_out, pos_in, pos_out, outroot)

    return
Exemple #5
0
def test_pspl_parallax_boden1998(t0):
    """
    Compare PSPL models with and without parallax for an LMC-like event.

    Original author's note: "I can get this one to match Figure 6 of
    Boden et al. 1998."

    Builds a ``model.PSPL`` and a ``model.PSPL_parallax`` with identical
    hard-coded lens/source parameters, evaluates both over
    [t0 - 500, t0 + 500) in steps of 1, and draws six figures
    (amplification, on-sky tracks, astrometric shift amplitude, shift on
    sky, shift vs. thetaS, and lensed/unlensed thetaS tracks).  No figure
    is saved to disk.

    Parameters
    ----------
    t0 : number
        Passed to both models as t0; also the center of the time grid.

    Returns
    -------
    None
    """
    # Scenarios from Paczynski 1998
    raL = 80.89375  # LMC R.A.
    # LMC Dec. (original note, truncated in the source: "This is the
    # sin \beta = -0.99 where \beta =")
    decL = -71.74
    mL = 0.1  # msun
    xS0 = np.array([0.000, 0.088e-3])  # arcsec
    beta = -0.16  # mas  same as p=0.4
    muS = np.array([-2.0, 1.5])
    muL = np.array([0.0, 0.0])
    # NOTE(review): the original comment said "10 kpc" for the value 8e3;
    # confirm the intended lens distance.
    dL = 8e3
    dS = 50e3  # 50 kpc in LMC
    imag = 19.0

    # No parallax
    pspl_n = model.PSPL(mL, t0, xS0, beta, muL, muS, dL, dS, imag)
    print('pspl_n.u0', pspl_n.u0)
    print('pspl_n.muS', pspl_n.muS)
    print('pspl_n.u0_hat', pspl_n.u0_hat)
    print('pspl_n.thetaE_hat', pspl_n.thetaE_hat)

    # With parallax
    pspl_p = model.PSPL_parallax(raL, decL, mL, t0, xS0, beta, muL, muS, dL,
                                 dS, imag)

    t = np.arange(t0 - 500, t0 + 500, 1)
    dt = t - pspl_n.t0

    A_n = pspl_n.get_amplification(t)
    A_p = pspl_p.get_amplification(t)

    xS_n = pspl_n.get_astrometry(t)
    xS_p_unlens = pspl_p.get_astrometry_unlensed(t)
    xS_p_lensed = pspl_p.get_astrometry(t)
    xL_p_unlens = pspl_p.get_lens_astrometry(t)

    # Source position relative to the lens, scaled by 1e3 (arcsec -> mas).
    thetaS = (xS_p_unlens - xL_p_unlens) * 1e3  # mas
    thetaS_lensed = (xS_p_lensed - xL_p_unlens) * 1e3  # mas

    shift_n = pspl_n.get_centroid_shift(t)  # mas
    shift_p = (xS_p_lensed - xS_p_unlens) * 1e3  # mas
    shift_n_amp = np.linalg.norm(shift_n, axis=1)
    shift_p_amp = np.linalg.norm(shift_p, axis=1)

    # Plot the amplification
    fig1 = plt.figure(1)
    plt.clf()
    f1_1 = fig1.add_axes((0.1, 0.3, 0.8, 0.6))
    plt.plot(dt, 2.5 * np.log10(A_n), 'b-', label='No parallax')
    plt.plot(dt, 2.5 * np.log10(A_p), 'r-', label='Parallax')
    plt.legend()
    plt.ylabel('2.5 * log(A)')
    f1_1.set_xticklabels([])

    fig1.add_axes((0.1, 0.1, 0.8, 0.2))
    plt.plot(dt,
             2.5 * (np.log10(A_p) - np.log10(A_n)),
             'k-',
             label='Par - No par')
    plt.axhline(0, linestyle='--', color='k')
    plt.legend()
    plt.ylabel('Diff')
    plt.xlabel('t - t0 (MJD)')

    # Index of the grid epoch closest to t0, marked with an 'x' below.
    idx = np.argmin(np.abs(t - t0))

    # Plot the positions of everything
    plt.figure(2)
    plt.clf()
    plt.plot(xS_n[:, 0],
             xS_n[:, 1],
             'r--',
             mfc='none',
             mec='red',
             label='No parallax model')
    plt.plot(xS_p_unlens[:, 0],
             xS_p_unlens[:, 1],
             'b--',
             mfc='blue',
             mec='blue',
             label='Parallax model, unlensed')
    plt.plot(xS_p_lensed[:, 0],
             xS_p_lensed[:, 1],
             'b-',
             label='Parallax model, lensed')
    plt.plot(xL_p_unlens[:, 0],
             xL_p_unlens[:, 1],
             'g--',
             mfc='none',
             mec='green',
             label='Parallax model, Lens')
    plt.plot(xS_n[idx, 0], xS_n[idx, 1], 'rx')
    plt.plot(xS_p_unlens[idx, 0], xS_p_unlens[idx, 1], 'bx')
    plt.plot(xS_p_lensed[idx, 0], xS_p_lensed[idx, 1], 'bx')
    plt.plot(xL_p_unlens[idx, 0], xL_p_unlens[idx, 1], 'gx')
    plt.legend()
    plt.gca().invert_xaxis()
    plt.xlabel('R.A. (")')
    plt.ylabel('Dec. (")')

    # Check just the astrometric shift part.
    fig3 = plt.figure(3)
    plt.clf()
    f1_3 = fig3.add_axes((0.2, 0.3, 0.7, 0.6))
    plt.plot(dt, shift_n_amp, 'r--', label='No parallax model')
    plt.plot(dt, shift_p_amp, 'b--', label='Parallax model')
    plt.legend(fontsize=10)
    plt.ylabel('Astrometric Shift (mas)')
    f1_3.set_xticklabels([])

    fig3.add_axes((0.2, 0.1, 0.7, 0.2))
    plt.plot(dt, shift_p_amp - shift_n_amp, 'k-', label='Par - No par')
    plt.legend()
    plt.axhline(0, linestyle='--', color='k')
    plt.xlabel('t - t0 (MJD)')
    plt.ylabel('Res.')

    plt.figure(4)
    plt.clf()
    plt.plot(shift_n[:, 0], shift_n[:, 1], 'r-', label='No parallax')
    plt.plot(shift_p[:, 0], shift_p[:, 1], 'b-', label='Parallax')
    plt.axhline(0, linestyle='--')
    plt.axvline(0, linestyle='--')
    plt.gca().invert_xaxis()
    plt.legend(loc='upper left')
    plt.xlabel('Shift RA (mas)')
    plt.ylabel('Shift Dec (mas)')
    plt.axis('equal')

    # thetaS was scaled by 1e3 above, so these axes are in mas, not arcsec.
    plt.figure(5)
    plt.clf()
    plt.plot(thetaS[:, 0], shift_p[:, 0], 'r-', label='RA')
    plt.plot(thetaS[:, 1], shift_p[:, 1], 'b-', label='Dec')
    plt.legend()  # the curves carry labels, so show them
    plt.xlabel('thetaS (mas)')
    plt.ylabel('Shift (mas)')

    plt.figure(6)
    plt.clf()
    plt.plot(thetaS[:, 0], thetaS[:, 1], 'r-', label='Unlensed')
    plt.plot(thetaS_lensed[:, 0], thetaS_lensed[:, 1], 'b-', label='Lensed')
    plt.axvline(0, linestyle='--', color='k')
    plt.legend()
    plt.xlabel('thetaS_E (mas)')
    plt.ylabel('thetaS_N (mas)')

    # Report both models; the crossing time line used to print the
    # no-parallax value twice.
    print('Einstein radius: ', pspl_n.thetaE_amp, pspl_p.thetaE_amp)
    print('Einstein crossing time: ', pspl_n.tE, pspl_p.tE)

    return
Exemple #6
0
def test_pspl_parallax(raL,
                       decL,
                       mL,
                       t0,
                       xS0,
                       beta,
                       muS,
                       muL,
                       dL,
                       dS,
                       imag,
                       outdir=''):
    """
    Compare PSPL models with and without parallax and save the figures.

    Builds a ``model.PSPL`` and a ``model.PSPL_parallax`` from the same
    parameters, evaluates them over [t0 - 500, t0 + 500) in steps of 1,
    and writes four figures to ``outdir``: 'amp_v_time.png',
    'on_sky.png', 'shift_amp_v_t.png' and 'shift_on_sky.png'.

    Parameters
    ----------
    raL, decL :
        Passed only to the parallax model.
    mL, t0, xS0, beta, dL, dS, imag :
        Passed to both models.
    muS, muL :
        Source and lens proper motions.  Note the signature takes muS
        before muL, while both model constructors take muL before muS.
    outdir : str
        Prefix (including any trailing separator) for the output files.

    Returns
    -------
    None
    """
    # No parallax
    pspl_n = model.PSPL(mL, t0, xS0, beta, muL, muS, dL, dS, imag)
    print('pspl_n.u0', pspl_n.u0)
    print('pspl_n.muS', pspl_n.muS)
    print('pspl_n.u0_hat', pspl_n.u0_hat)
    print('pspl_n.thetaE_hat', pspl_n.thetaE_hat)

    # With parallax
    pspl_p = model.PSPL_parallax(raL, decL, mL, t0, xS0, beta, muL, muS, dL,
                                 dS, imag)

    t = np.arange(t0 - 500, t0 + 500, 1)
    dt = t - pspl_n.t0

    A_n = pspl_n.get_amplification(t)
    A_p = pspl_p.get_amplification(t)

    xS_n = pspl_n.get_astrometry(t)
    xS_p_unlens = pspl_p.get_astrometry_unlensed(t)
    xS_p_lensed = pspl_p.get_astrometry(t)

    # Plot the amplification
    fig1 = plt.figure(1)
    plt.clf()
    f1_1 = fig1.add_axes((0.20, 0.3, 0.75, 0.6))
    plt.plot(dt, 2.5 * np.log10(A_n), 'b-', label='No parallax')
    plt.plot(dt, 2.5 * np.log10(A_p), 'r-', label='Parallax')
    plt.legend(fontsize=10)
    plt.ylabel('2.5 * log(A)')
    f1_1.set_xticklabels([])

    fig1.add_axes((0.20, 0.1, 0.75, 0.2))
    plt.plot(dt,
             2.5 * (np.log10(A_p) - np.log10(A_n)),
             'k-',
             label='Par - No par')
    plt.axhline(0, linestyle='--', color='k')
    plt.legend(fontsize=10)
    plt.ylabel('Diff')
    plt.xlabel('t - t0 (MJD)')

    plt.savefig(outdir + 'amp_v_time.png')
    print("save to " + outdir)

    # Plot the positions of everything, scaled by 1e3 (labelled as mas).
    plt.figure(2)
    plt.clf()
    plt.plot(xS_n[:, 0] * 1e3,
             xS_n[:, 1] * 1e3,
             'r--',
             mfc='none',
             mec='red',
             label='No parallax model')
    plt.plot(xS_p_unlens[:, 0] * 1e3,
             xS_p_unlens[:, 1] * 1e3,
             'b--',
             mfc='none',
             mec='blue',
             label='Parallax model, unlensed')
    plt.plot(xS_p_lensed[:, 0] * 1e3,
             xS_p_lensed[:, 1] * 1e3,
             'b-',
             label='Parallax model, lensed')
    plt.legend(fontsize=10)
    plt.axis('equal')
    plt.xlabel('R.A. (mas)')
    plt.ylabel('Dec. (mas)')
    plt.savefig(outdir + 'on_sky.png')

    # Check just the astrometric shift part.
    shift_n = pspl_n.get_centroid_shift(t)  # mas
    shift_p = (xS_p_lensed - xS_p_unlens) * 1e3  # mas
    shift_n_amp = np.linalg.norm(shift_n, axis=1)
    shift_p_amp = np.linalg.norm(shift_p, axis=1)

    fig3 = plt.figure(3)
    plt.clf()
    f1_3 = fig3.add_axes((0.20, 0.3, 0.75, 0.6))
    plt.plot(dt, shift_n_amp, 'r--', label='No parallax model')
    plt.plot(dt, shift_p_amp, 'b--', label='Parallax model')
    plt.ylabel('Astrometric Shift (mas)')
    plt.legend(fontsize=10)
    f1_3.set_xticklabels([])

    fig3.add_axes((0.20, 0.1, 0.75, 0.2))
    plt.plot(dt, shift_p_amp - shift_n_amp, 'k-', label='Par - No par')
    plt.legend(fontsize=10)
    plt.axhline(0, linestyle='--', color='k')
    plt.ylabel('Diff (mas)')
    plt.xlabel('t - t0 (MJD)')

    plt.savefig(outdir + 'shift_amp_v_t.png')

    plt.figure(4)
    plt.clf()
    plt.plot(shift_n[:, 0], shift_n[:, 1], 'r-', label='No parallax')
    plt.plot(shift_p[:, 0], shift_p[:, 1], 'b-', label='Parallax')
    plt.axhline(0, linestyle='--')
    plt.axvline(0, linestyle='--')
    plt.gca().invert_xaxis()
    plt.legend(fontsize=10)
    plt.xlabel('Shift RA (mas)')
    plt.ylabel('Shift Dec (mas)')
    plt.axis('equal')
    plt.savefig(outdir + 'shift_on_sky.png')

    # Report both models; the crossing time line used to print the
    # no-parallax value twice.
    print('Einstein radius: ', pspl_n.thetaE_amp, pspl_p.thetaE_amp)
    print('Einstein crossing time: ', pspl_n.tE, pspl_p.tE)

    return
def plot_astrometry_lens_ref():
    """
    Plot the astrometry as seen on the sky when in the
    rest frame of the lens.

    Builds a ``model.PSPL_parallax`` with hard-coded parameters (lens
    proper motion set to zero), evaluates it over [t0 - 5000, t0 + 5000)
    in steps of 10, and draws a three-panel figure: the light curve, the
    on-sky tracks of lens and source, and the astrometric shift.  The
    figure is saved to ``plot_dir + 'phot_astrom.png'`` (``plot_dir`` is
    not defined in this function) and the maximum shift amplitude is
    printed.

    Returns
    -------
    None
    """
    # Scenario from Belokurov and Evans 2002 (Figure 1)
    raL = 17.5
    decL = -30.0
    mL = 10.0  # msun
    t0 = 57650.0
    xS0 = np.array([0.000, 0.003])
    beta = -3.0  # mas
    muS = np.array([-8.0, 0.0])
    muL = np.array([0.0, 0.0])
    dL = 3000.0
    dS = 6000.0
    imag = 19.0

    # With parallax
    pspl_p = model.PSPL_parallax(raL, decL, mL, t0, xS0, beta, muL, muS, dL,
                                 dS, imag)

    # In Days.
    t = np.arange(t0 - 5000, t0 + 5000, 10)
    dt = t - pspl_p.t0

    i_p = pspl_p.get_photometry(t)

    xS_p_unlens = pspl_p.get_astrometry_unlensed(t)
    xS_p_lensed = pspl_p.get_astrometry(t)
    xL_p = pspl_p.get_lens_astrometry(t)

    # Close any existing figure 1 so the requested figsize takes effect.
    plt.close(1)
    fig = plt.figure(1, figsize=(18, 5.67))
    plt.clf()
    plt.subplots_adjust(left=0.065, right=0.98, bottom=0.16, wspace=0.34)
    ax1 = fig.add_subplot(131)
    ax2 = fig.add_subplot(132)
    ax3 = fig.add_subplot(133)

    # Panel 1: light curve (magnitudes, so brighter is up).
    ax1.invert_yaxis()
    ax1.plot(dt, i_p, 'r-')
    ax1.set_ylabel('I-band (mag)')
    ax1.set_xlabel(r'$t - t_0$ (days)')
    ax1.set_xlim(-1000, 1000)

    # Panel 2: positions of everything, scaled by 1e3 (labelled as mas).
    ax2.plot(xL_p[:, 0] * 1e3,
             xL_p[:, 1] * 1e3,
             'k--',
             mfc='none',
             mec='grey',
             label='Lens')
    ax2.plot(xS_p_unlens[:, 0] * 1e3,
             xS_p_unlens[:, 1] * 1e3,
             'b--',
             mfc='none',
             mec='blue',
             label='Src, unlensed')
    ax2.plot(xS_p_lensed[:, 0] * 1e3,
             xS_p_lensed[:, 1] * 1e3,
             'r-',
             label='Src, lensed')
    ax2.legend(fontsize=10)
    ax2.invert_xaxis()
    ax2.set_xlabel('R.A. (mas)')
    ax2.set_ylabel('Dec. (mas)')
    ax2.axis('equal')
    lim = 18  # same units as the plotted, 1e3-scaled positions (mas)
    ax2.set_xlim(lim, -lim)
    ax2.set_ylim(-lim, lim)

    # Panel 3: just the astrometric shift part.
    shift_p = (xS_p_lensed - xS_p_unlens) * 1e3  # mas
    shift_p_amp = np.linalg.norm(shift_p, axis=1)

    ax3.plot(shift_p[:, 0], shift_p[:, 1], 'r-')
    ax3.axhline(0, linestyle='--', color='grey')
    ax3.axvline(0, linestyle='--', color='grey')
    ax3.invert_xaxis()
    ax3.set_xlabel(r'$\Delta$R.A. (mas)')
    ax3.set_ylabel(r'$\Delta$Dec. (mas)')
    ax3.axis('equal')

    plt.savefig(plot_dir + 'phot_astrom.png')

    print('Maximum astrometric shift: {0:.2f} mas'.format(shift_p_amp.max()))

    return