from __future__ import division, print_function
import opticstools as ot
import numpy as np
import matplotlib.pyplot as plt

# Parameters of the fibre mode to display: the_V is passed to ot.mode_2d as
# its first argument, the_j and the_n as its `j` and `n` keyword arguments.
the_V = 10
the_j = 0
the_n = 0

# Start from a cleared figure and keep a handle on its current axes.
plt.clf()
ax = plt.gca()

# Mode evaluated on a grid with sz=100; only its real part is displayed.
m33 = ot.mode_2d(the_V, 10, j=the_j, n=the_n, sz=100)

# Half-width of the displayed field: 50 pixels at 0.3 per pixel.
# NOTE(review): 0.3 is presumably the default sampling of ot.mode_2d in
# microns per pixel -- confirm against opticstools.
half_width = 50*0.3
plt.imshow(m33.real, extent=[-half_width, half_width, -half_width, half_width])

# Outline a circle of radius 10, the same value passed as the second
# argument of ot.mode_2d above (presumably the core radius -- confirm).
circle = plt.Circle((0, 0), 10, color='k', fill=False)
ax.add_artist(circle)

ax.set_xlabel('X pos (microns)')
ax.set_ylabel('Y pos (microns)')
ax.set_title("n={0:d} j={1:d} V={2:5.1f}".format(the_n, the_j, the_V))
# Lenslet geometry.
llet_w = 1.0         # Lenslet width in mm
llet_f = 4.64 * 1.1  # Lenslet focal length in mm
llet_offset = 0.67   # Lenslet mask shift, as a fraction of the lenslet width

# Focal ratios swept by the coupling loop: nf values from 700 in steps of 50.
nf = 20
f_ratios = np.arange(700, 700 + 50*nf, 50)

# Diameter of the central obstruction as a fraction of the pupil diameter.
obstruct = 0.25

# Fibre offset (1.0e-6 corresponds to one micron) and a matching text label;
# swap in one of the commented alternatives to model a misaligned fibre.
offset = 0.0e-6
label = 'Perfect Alignment'
#offset = 1.0e-6; label = '1 micron offset'
#offset = 2.0e-6; label = '2 microns offset'
#----

# Grid scales, fibre far field and lenslet mask shared by the f-ratio loop.
#
# NOTE(review): `wave`, `sz`, `m_pix`, `core_diam` and `numerical_aperture`
# are not assigned anywhere in the visible part of this file. They must be
# defined before this point, otherwise the next statement raises NameError
# -- confirm where they are meant to come from.

# Pixel scale of the Fourier-transformed grid (presumably radians per pixel,
# given `wave` and `m_pix` in the same length unit -- confirm).
rad_pix = wave/(sz*m_pix)
# The same pixel scaled by the lenslet focal length; llet_f is in mm, hence /1e3.
m_pix_llet = rad_pix*llet_f/1e3

# Fibre V number and the j=0, n=0 mode sampled at m_pix on an sz grid.
V = ot.compute_v_number(wave, core_diam/2, numerical_aperture)
fib_mode = ot.mode_2d(V, core_diam/2, j=0, n=0, sampling=m_pix,  sz=sz)
# Real part of the centred 2-D FFT of the fibre mode.
fib_angle = np.real(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(fib_mode))))

# Lenslet mask of width llet_w/llet_f expressed in rad_pix pixels, shifted by
# llet_offset lenslet widths.  An explicit axis is required here: without it
# np.roll flattens the 2-D mask, so any part shifted past the edge of a row
# would spill into the next row instead of wrapping within its own row.
# axis=1 is the direction the flattened roll moved the mask in.
# NOTE(review): confirm the lenslet offset is meant along axis 1.
llet = np.roll(ot.square(sz, llet_w/rad_pix/llet_f),
               int(llet_offset*llet_w/rad_pix/llet_f), axis=1)
# Shift the fibre mode by the alignment offset, truncated to whole pixels.
# NOTE(review): fib_angle was computed above from the unshifted mode, so
# `mode` below does not include this offset -- confirm that is intended.
fib_mode = np.roll(fib_mode,int(offset/m_pix), axis=0)

# Fibre far field masked by the lenslet.
mode = llet * fib_angle

# Result lists for the f-ratio loop that follows.
couplings1 = []
couplings2 = []
for f_ratio in f_ratios:
    l_d_pix = f_ratio*wave/m_pix_llet 
    pup_diam_pix = sz/l_d_pix
    pup = ot.circle(sz, pup_diam_pix) - ot.circle(sz, pup_diam_pix*obstruct)
    psf = np.real(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(pup))))
    psf_trunc = psf * llet