forked from astrofrog/wcsaxes-graticules
/
draw.py
89 lines (66 loc) · 2.38 KB
/
draw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import glob
import numpy as np
from astropy.wcs import WCS
from astropy.io import fits
import matplotlib.pyplot as plt
from matplotlib.collections import PathCollection
from graticules import get_lon_lat_path
class FITSWCSWrapper(object):
    """Adapter giving a WCS object a ``world2pix`` / ``pix2world`` interface.

    The wrapped object must provide ``wcs_world2pix`` and ``wcs_pix2world``
    methods taking ``(coordinates, origin)``; in this file it is an
    ``astropy.wcs.WCS`` instance.
    """

    # Origin argument forwarded on every conversion (1 is the FITS-style,
    # 1-based pixel convention in astropy's WCS API).
    _ORIGIN = 1

    def __init__(self, wcs):
        """Store the WCS object that performs the actual conversions."""
        self.wcs = wcs

    def world2pix(self, world):
        """Return ``world`` converted by ``wcs.wcs_world2pix`` with origin 1."""
        convert = self.wcs.wcs_world2pix
        return convert(world, self._ORIGIN)

    def pix2world(self, pixel):
        """Return ``pixel`` converted by ``wcs.wcs_pix2world`` with origin 1."""
        convert = self.wcs.wcs_pix2world
        return convert(pixel, self._ORIGIN)
class PolarTrans(object):
    """Transformation between polar ``(r, theta)`` and Cartesian ``(x, y)``.

    Provides the same ``world2pix`` / ``pix2world`` method pair as
    ``FITSWCSWrapper``. Both methods take a 2-D array with one point per row
    and two columns, and return an array of the same shape.
    """

    def world2pix(self, r_theta):
        """Convert polar coordinates to Cartesian coordinates.

        Parameters
        ----------
        r_theta : numpy.ndarray, shape (N, 2)
            Column 0 is the radius, column 1 the angle in radians.

        Returns
        -------
        numpy.ndarray, shape (N, 2)
            Column 0 is ``r * cos(theta)``, column 1 is ``r * sin(theta)``.
        """
        r, theta = r_theta[:, 0], r_theta[:, 1]
        x = r * np.cos(theta)
        y = r * np.sin(theta)
        return np.vstack([x, y]).transpose()

    def pix2world(self, x_y):
        """Convert Cartesian coordinates to polar coordinates.

        Parameters
        ----------
        x_y : numpy.ndarray, shape (N, 2)
            Column 0 is x, column 1 is y.

        Returns
        -------
        numpy.ndarray, shape (N, 2)
            Column 0 is the radius ``sqrt(x**2 + y**2)``, column 1 the angle
            ``arctan2(y, x)`` in radians, in the range [-pi, pi].
        """
        x, y = x_y[:, 0], x_y[:, 1]
        r = np.sqrt(x * x + y * y)
        # np.arctan2 is available in every NumPy release; the np.atan2 alias
        # only exists from NumPy 2.0 on and raises AttributeError before that.
        theta = np.arctan2(y, x)
        return np.vstack([r, theta]).transpose()
# For every FITS file under data/, draw a longitude/latitude grid from its
# (zoomed-out) WCS and save the figure as a PNG next to the input file.
for filename in glob.glob(os.path.join('data', '*.fits')):

    print(filename)

    # Read in the header and zoom out: both pixel increments are scaled by 50
    # and the reference pixel is set to half the image size along each axis.
    header = fits.getheader(filename)
    for key in ('CDELT1', 'CDELT2'):
        header[key] *= 50
    for crpix, naxis in (('CRPIX1', 'NAXIS1'), ('CRPIX2', 'NAXIS2')):
        header[crpix] = header[naxis] / 2.

    # Parse the WCS transformation and use it as the grid transformation
    wcs = FITSWCSWrapper(WCS(header))
    trans = wcs

    # Range of world coordinates covered by the grid, as [lon, lat]
    wmin = [-180., -90.]
    wmax = [+180., +90.]

    # Create the figure; axis limits span the full image in pixel coordinates
    fig = plt.figure(figsize=(6,6))
    ax = fig.add_axes([0.05, 0.05, 0.9, 0.9])
    ax.set_xlim(0.5, header['NAXIS1'] + 0.5)
    ax.set_ylim(0.5, header['NAXIS2'] + 0.5)
    ax.text(0.5, 0.9, os.path.basename(filename), ha='center',
            transform=ax.transAxes, size=18)

    # Define grid lines
    NG = 18    # number of grid levels sampled across each coordinate range
    N = 1000   # number of sample points along every grid line

    lon_samples = np.linspace(wmin[0], wmax[0], N)
    lat_samples = np.linspace(wmin[1], wmax[1], N)

    # Lines of constant latitude, each sampled along the full longitude range
    parallels = [
        np.vstack([lon_samples, np.repeat(latval, N)]).transpose()
        for latval in np.linspace(wmin[1], wmax[1], NG)
    ]

    # Lines of constant longitude; the first level is skipped, presumably
    # because -180 and +180 describe the same meridian.
    meridians = [
        np.vstack([np.repeat(lonval, N), lat_samples]).transpose()
        for lonval in np.linspace(wmin[0], wmax[0], NG)[1:]
    ]

    # Convert every grid line to a path, parallels first, then meridians
    paths = []
    for lon_lat in parallels + meridians:
        paths.append(get_lon_lat_path(ax, trans, lon_lat))

    ax.add_collection(PathCollection(paths, edgecolors='b', facecolors='none',
                                     alpha=0.4))

    # Hide the pixel axes, write the image and release the figure
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_visible(False)

    fig.savefig(filename.replace('.fits', '.png'))
    plt.close(fig)