Ejemplo n.º 1
0
def plot_widths(width_data, canvassize=None, msize=None,
                xval='year', yunits='phase',
                xticks=True, yticks=True, xlabel=True, ylabel=True,
                sym='o', colour=None, mf_colour=None, me_colour=None,
                csize=3, xlim=None, ylim=None,
                figtext=None, gridlines=None,
                ticklabelsize=18, axislabelsize=18):
    """Plot pulse widths with error bars against time for one or more data sets.

    Parameters
    ----------
    width_data : dict or list of dict
        Each dict provides 'mjd', 'width' and 'werr'.  Widths are treated as
        being in the units of the data file ('phase'); they are scaled by 360
        for 'deg' and by 2*pi for 'rad'.  A 'serial' entry is required when
        xval='serial'.  For xval='mjd0' or 'year' a key of that name is
        added to every dict in place.
    canvassize : passed to plt.figure(figsize=...).
    msize : marker size for the data points (matplotlib default if None).
    xval : one of 'mjd', 'mjd0', 'year', 'serial'.
    yunits : one of 'phase', 'deg', 'rad', 'cos_phi'.
    xticks, yticks : show tick labels on the x / y axis.
    xlabel, ylabel : draw the x / y axis label.
    sym : matplotlib format string for the markers.
    colour : one colour, or a list with one colour per data set; automatic
        (black, or a gist_heat sequence for several sets) if None.
    mf_colour, me_colour : marker face / edge colour, single or per data
        set; default to the data-set colour.
    csize : error-bar cap size.
    xlim, ylim : explicit axis limits; otherwise derived from the data.
    figtext : list of (x, y, text) tuples drawn in data coordinates.
    gridlines : iterable of y values at which dashed lines are drawn.
    ticklabelsize, axislabelsize : font sizes.

    Returns
    -------
    The matplotlib Axes the widths were drawn on.
    """
    possible_xval = ['mjd', 'mjd0', 'year', 'serial']
    if xval not in possible_xval:
        print("There is no value " + xval +
              " in the data that we can plot on x axis. Exiting.")
        exit()

    possible_yunits = ['phase', 'deg', 'rad', 'cos_phi']
    if yunits not in possible_yunits:
        print(yunits + " is not an option for y axis units. Exiting.")
        exit()

    # A single data set is wrapped in a one-element list so that the rest of
    # the function handles one or many sets the same way.
    if isinstance(width_data, dict):
        width_data = [width_data]
    n_sets = len(width_data)

    # 'mjd0' subtracts the integer part of the smallest MJD over all sets.
    if xval == 'mjd0':
        mjdint = np.floor(min(np.amin(width['mjd']) for width in width_data))
        for width in width_data:
            width['mjd0'] = width['mjd'] - mjdint
    elif xval == 'year':
        for width in width_data:
            width['year'] = [mjd.mjdtoyear(m) for m in width['mjd']]

    fig = plt.figure(figsize=canvassize)
    ax = fig.add_axes([0.12, 0.14, 0.86, 0.83])
    ax.xaxis.set_tick_params(labelsize=ticklabelsize, pad=8)
    ax.yaxis.set_tick_params(labelsize=ticklabelsize, pad=8)

    if xlabel:
        if xval == 'serial':
            ax.set_xlabel('Serial number', fontsize=axislabelsize, labelpad=12)
        elif xval == 'mjd':
            ax.set_xlabel('MJD', fontsize=axislabelsize, labelpad=12)
        elif xval == 'mjd0':
            ax.set_xlabel('MJD - {:d}'.format(int(mjdint)),
                          fontsize=axislabelsize, labelpad=12)
        elif xval == 'year':
            ax.set_xlabel('Year', fontsize=axislabelsize, labelpad=12)
            # Years are shown as integers.
            ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))

    if ylabel:
        # The label follows the units actually plotted.
        ylabels = {'phase': 'Pulse width (phase)',
                   'deg': 'Pulse width (degrees)',
                   'rad': 'Pulse width (radians)',
                   'cos_phi': 'cos(pulse width / 2)'}
        ax.set_ylabel(ylabels[yunits], fontsize=axislabelsize)

    # tick_params persists across tick re-creation.  Setting the
    # Tick.label1On attribute does not work in current matplotlib.
    if not xticks:
        ax.tick_params(axis='x', labelbottom=False, labeltop=False)
    if not yticks:
        ax.tick_params(axis='y', labelleft=False, labelright=False)

    if yunits == 'deg':
        ax.yaxis.set_major_formatter(FormatStrFormatter('%5.1f'))

    if gridlines is not None:
        for ycoord in gridlines:
            ax.axhline(ycoord, linestyle='--', color='black', linewidth=0.4)

    x_lo, x_hi, y_lo, y_hi = [], [], [], []
    for i_width, width in enumerate(width_data):
        # Colours for this data set.
        if colour is not None:
            clr = colour[i_width] if isinstance(colour, list) else colour
        elif n_sets == 1:
            clr = 'black'
        else:
            clr = cm.gist_heat(float(i_width) / float(n_sets))
        if isinstance(mf_colour, list):
            mf_clr = mf_colour[i_width]
        else:
            mf_clr = clr if mf_colour is None else mf_colour
        if isinstance(me_colour, list):
            me_clr = me_colour[i_width]
        else:
            me_clr = clr if me_colour is None else me_colour

        # Convert to the requested units.
        w = np.asarray(width['width'], dtype=float)
        werr = np.asarray(width['werr'], dtype=float)
        if yunits == 'deg':
            y_plot, yerr_plot = w * 360., werr * 360.
        elif yunits == 'rad':
            y_plot, yerr_plot = w * 2. * np.pi, werr * 2. * np.pi
        elif yunits == 'cos_phi':
            y_plot = np.cos(w / 2.)
            # First-order error propagation: |d cos(w/2) / dw| = |sin(w/2)|/2.
            yerr_plot = np.abs(np.sin(w / 2.)) * werr / 2.
        else:  # units of the data file
            y_plot, yerr_plot = w, werr

        x_plot = np.asarray(width[xval], dtype=float)
        x_lo.append(np.amin(x_plot))
        x_hi.append(np.amax(x_plot))
        y_lo.append(np.amin(y_plot - yerr_plot))
        y_hi.append(np.amax(y_plot + yerr_plot))

        ax.plot(x_plot, y_plot, sym, color=clr, mfc=mf_clr, mec=me_clr,
                markersize=msize)
        # fmt='none' draws only the error bars (the points are drawn above).
        ax.errorbar(x_plot, y_plot, yerr=yerr_plot, capsize=csize,
                    fmt='none', ecolor=clr)

    if xlim is not None:
        ax.set_xlim(xlim)
    elif x_lo:
        xmin, xmax = min(x_lo), max(x_hi)
        xspan = abs(xmax - xmin)
        ax.set_xlim(xmin - 0.025 * xspan, xmax + 0.025 * xspan)
    if ylim is not None:
        ax.set_ylim(ylim)
    elif y_lo:
        ymin, ymax = min(y_lo), max(y_hi)
        yspan = abs(ymax - ymin)
        ax.set_ylim(ymin - 0.1 * yspan, ymax + 0.1 * yspan)

    # Figure text must be a list of tuples: [(x, y, text), (x, y, text), ...]
    if figtext is not None:
        for txt in figtext:
            ax.text(txt[0], txt[1], txt[2], fontsize=10,
                    horizontalalignment='center',
                    verticalalignment='center')

    return ax
Ejemplo n.º 2
0
def plot_dmx(dmx_data, canvassize=None, msize=None,
             xval='year',
             xticks=True, yticks=True, xlabel=True, ylabel=True,
             sym='o', symsize=3.4,
             colour='black', mf_colour=None, me_colour=None,
             csize=2, xlim=None, ylim=None,
             figtext=None, gridlines=None):
    """Plot DMX (delta DM) values with error bars against time.

    Parameters
    ----------
    dmx_data : dict or list of dict
        Each dict provides 'mjd', 'delta_dm' and 'delta_dm_err' (and 'serial'
        when xval='serial').  For xval='mjd0' or 'year' a key of that name is
        added to every dict in place.
    canvassize : passed to plt.figure(figsize=...).
    msize : if given, overrides symsize as the marker size.
    xval : one of 'mjd', 'mjd0', 'year', 'serial'.
    xticks, yticks : show tick labels on the x / y axis.
    xlabel, ylabel : draw the x / y axis label.
    sym : matplotlib marker symbol.
    symsize : marker size.
    colour : one colour, or a list with one colour per data set.
    mf_colour, me_colour : marker face / edge colour, single or per set.
    csize : error-bar cap size.
    xlim, ylim : explicit axis limits; otherwise derived from the data.
    figtext : list of (x, y, text) tuples drawn in data coordinates.
    gridlines : iterable of y values at which dashed lines are drawn.

    Returns
    -------
    The matplotlib Axes the values were drawn on.
    """
    possible_xval = ['mjd', 'mjd0', 'year', 'serial']
    if xval not in possible_xval:
        print("There is no value " + xval +
              " in the data that we can plot on x axis. Exiting.")
        exit()

    # A single data set is wrapped in a one-element list so that the rest of
    # the function handles one or many sets the same way.
    if isinstance(dmx_data, dict):
        dmx_data = [dmx_data]

    # 'mjd0' subtracts the integer part of the smallest MJD over all sets.
    if xval == 'mjd0':
        mjdint = np.floor(min(np.amin(dmx['mjd']) for dmx in dmx_data))
        for dmx in dmx_data:
            dmx['mjd0'] = dmx['mjd'] - mjdint
    elif xval == 'year':
        for dmx in dmx_data:
            dmx['year'] = [mjd.mjdtoyear(m) for m in dmx['mjd']]

    fig = plt.figure(figsize=canvassize)
    ax = fig.add_axes([0.15, 0.1, 0.8, 0.85])
    ax.xaxis.set_tick_params(labelsize=16)
    ax.yaxis.set_tick_params(labelsize=16)

    if xlabel:
        if xval == 'serial':
            ax.set_xlabel('Serial number', fontsize=18)
        elif xval == 'mjd':
            ax.set_xlabel('MJD', fontsize=18)
        elif xval == 'mjd0':
            ax.set_xlabel('MJD - {:d}'.format(int(mjdint)), fontsize=18)
        elif xval == 'year':
            ax.set_xlabel('Year', fontsize=18)
            # Years are shown as integers.
            ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))

    if ylabel:
        ax.set_ylabel('Delta DM (pc cm$^{-3}$)', fontsize=18)

    # tick_params persists across tick re-creation.  Setting the
    # Tick.label1On attribute does not work in current matplotlib.
    if not xticks:
        ax.tick_params(axis='x', labelbottom=False, labeltop=False)
    if not yticks:
        ax.tick_params(axis='y', labelleft=False, labelright=False)

    # Fixed-point tick labels for the DM offsets.
    ax.yaxis.set_major_formatter(FormatStrFormatter('%6.3f'))

    if gridlines is not None:
        for ycoord in gridlines:
            ax.axhline(ycoord, linestyle='--', color='black', linewidth=0.4)

    marker_size = symsize if msize is None else msize
    x_lo, x_hi, y_lo, y_hi = [], [], [], []
    for i_dmx, dmx in enumerate(dmx_data):
        clr = colour[i_dmx] if isinstance(colour, list) else colour
        mf_clr = mf_colour[i_dmx] if isinstance(mf_colour, list) else mf_colour
        me_clr = me_colour[i_dmx] if isinstance(me_colour, list) else me_colour

        x_plot = np.asarray(dmx[xval], dtype=float)
        y_plot = np.asarray(dmx['delta_dm'], dtype=float)
        yerr_plot = np.asarray(dmx['delta_dm_err'], dtype=float)

        x_lo.append(np.amin(x_plot))
        x_hi.append(np.amax(x_plot))
        y_lo.append(np.amin(y_plot - yerr_plot))
        y_hi.append(np.amax(y_plot + yerr_plot))

        ax.plot(x_plot, y_plot, marker=sym, markersize=marker_size,
                linestyle='None', color=clr, mfc=mf_clr, mec=me_clr)
        # fmt='none' draws only the error bars (the points are drawn above).
        ax.errorbar(x_plot, y_plot, yerr=yerr_plot, capsize=csize,
                    fmt='none', ecolor=clr)

    if xlim is not None:
        ax.set_xlim(xlim)
    elif x_lo:
        xmin, xmax = min(x_lo), max(x_hi)
        xspan = abs(xmax - xmin)
        ax.set_xlim(xmin - 0.025 * xspan, xmax + 0.025 * xspan)
    if ylim is not None:
        ax.set_ylim(ylim)
    elif y_lo:
        ymin, ymax = min(y_lo), max(y_hi)
        yspan = abs(ymax - ymin)
        ax.set_ylim(ymin - 0.1 * yspan, ymax + 0.1 * yspan)

    # Figure text must be a list of tuples: [(x, y, text), (x, y, text), ...]
    if figtext is not None:
        for txt in figtext:
            ax.text(txt[0], txt[1], txt[2], fontsize=10,
                    horizontalalignment='center',
                    verticalalignment='center')

    return ax
Ejemplo n.º 3
0
def plot_resid(resid_data, info_plot=None, canvassize=None,
               axis_limits=None, binsize=None, resoffset=0.,
               preres=False, xunits='year', yunits='us',
               xticks=True, yticks=True, xlabel=True, ylabel=True,
               sym='o', symsize=1.8, colour=None, csize=1.3,
               xlim=None, ylim=None, figtext=None, gridlines=None,
               ticklabelsize=18, axislabelsize=18):
    """Plot timing residuals, one axes panel per residual data set.

    Parameters
    ----------
    resid_data : dict or list of dict
        Each dict provides 'mjd', 'res', 'reserr', 'info' (None, or per-TOA
        info IDs) and, when 'info' is not None, 'info_val' (the available
        IDs).  The conversion table gives 'us' a factor of 1, so 'res' is
        presumably in microseconds.  For xunits 'mjd0' / 'year' a key of
        that name is added to the dict in place; 'res' itself is not
        modified.
    info_plot : info ID, list of IDs (used for every panel), or list of
        lists (one per panel).  Only these IDs are plotted, in the given
        order; all IDs are plotted if None.
    canvassize : passed to plt.figure(figsize=...).
    axis_limits : [left, bottom, width, height] for every panel, or a list
        of such rectangles (one, or one per panel).  Panels are stacked
        evenly, with the first data set at the bottom, if None.
    binsize : if given, residuals are binned with bin_resids().
    resoffset : vertical offset added per info group (i_info * resoffset).
    preres : accepted for compatibility; not used.
    xunits : 'mjd', 'mjd0', 'year', 'orbphase' or 'serial', or one per panel.
        A list longer than resid_data (when resid_data has one set) plots
        that set once per unit.
    yunits : 's', 'ms', 'us' or 'ns', or one per panel (broadcast like xunits).
    xticks, yticks, xlabel, ylabel : booleans, or lists with one per panel.
    sym, symsize, csize : marker format string, marker size, cap size.
    colour : None (automatic), one colour, or a list with one colour per
        info ID.
    xlim, ylim : explicit limits, used as given.  Otherwise the limits are
        the plotted data range plus 2.5% (x) / 5% (y) on each side.
    figtext : list of (x, y, text) tuples drawn on every panel.
    gridlines : iterable of y values at which dashed lines are drawn.
    ticklabelsize, axislabelsize : font sizes.

    Returns
    -------
    A list of the Axes created, or None if the colour list does not match the
    number of info IDs.
    """
    # Wrap a single data set so that everything below can loop over a list.
    if not isinstance(resid_data, list):
        resid_data = [resid_data]

    # Broadcast the unit keywords against the data sets.  This may replicate
    # a single data set once per requested unit.
    resid_data, xunits = _resid_match_units(resid_data, xunits, 'xunits')
    resid_data, yunits = _resid_match_units(resid_data, yunits, 'yunits')
    n_plots = len(resid_data)

    xlabel = _resid_per_plot(xlabel, n_plots, 'xlabel')
    ylabel = _resid_per_plot(ylabel, n_plots, 'ylabel')
    xticks = _resid_per_plot(xticks, n_plots, 'xticks')
    yticks = _resid_per_plot(yticks, n_plots, 'yticks')
    info_plot = _resid_info_plot(info_plot, n_plots)
    axis_limits = _resid_axis_limits(axis_limits, n_plots)

    possible_xunits = ['mjd', 'mjd0', 'year', 'orbphase', 'serial']
    xlabels = {'serial': 'Serial number', 'orbphase': 'Orbital phase',
               'mjd': 'MJD', 'year': 'Year'}
    # Conversion factors applied to the input residuals for each y unit.
    yconv = {'s': 1e-6, 'ms': 1e-3, 'us': 1, 'ns': 1e3}
    ylabels = {'s': 'Residuals (s)', 'ms': 'Residuals (ms)',
               'us': r'Residuals ($\mu$s)', 'ns': 'Residuals (ns)'}

    fig = plt.figure(figsize=canvassize)
    ax = []
    for i_plot, res_data in enumerate(resid_data):
        xunit = xunits[i_plot]
        yunit = yunits[i_plot]

        this_ax = fig.add_axes(axis_limits[i_plot])
        ax.append(this_ax)
        this_ax.tick_params(labelsize=ticklabelsize, pad=10)

        if xunit not in possible_xunits:
            print("There is no value " + str(xunit) +
                  " in the data that we can plot on x axis. Exiting.")
            exit()
        if yunit not in yconv:
            print(str(yunit) + " is not an option for y axis units. Exiting.")
            exit()

        # 'mjd0' subtracts the integer part of this set's smallest MJD.
        if xunit == 'mjd0':
            mjdint = np.floor(np.amin(res_data['mjd']))
            res_data['mjd0'] = res_data['mjd'] - mjdint
        elif xunit == 'year':
            res_data['year'] = \
                np.array([mjd.mjdtoyear(m) for m in res_data['mjd']])

        # Convert into new arrays.  Scaling in place would corrupt the
        # caller's data and re-scale a data set that appears in several panels.
        x_all = np.asarray(res_data[xunit])
        res_all = np.asarray(res_data['res'], dtype=float) * yconv[yunit]
        reserr_all = np.asarray(res_data['reserr'], dtype=float) * yconv[yunit]

        if xlabel[i_plot]:
            if xunit == 'mjd0':
                label = 'MJD - {:d}'.format(int(mjdint))
            else:
                label = xlabels[xunit]
            this_ax.set_xlabel(label, fontsize=axislabelsize, labelpad=12)
            if xunit == 'year':
                this_ax.xaxis.set_major_formatter(FormatStrFormatter('%4d'))
        if ylabel[i_plot]:
            this_ax.set_ylabel(ylabels[yunit], fontsize=axislabelsize,
                               labelpad=6)

        # Bottom/left tick labels only, and only if requested.  tick_params
        # persists across tick re-creation, unlike Tick.label1On.
        this_ax.tick_params(axis='x', labelbottom=bool(xticks[i_plot]),
                            labeltop=False)
        this_ax.tick_params(axis='y', labelleft=bool(yticks[i_plot]),
                            labelright=False)

        if gridlines is not None:
            for ycoord in gridlines:
                this_ax.axhline(ycoord, linestyle='--', color='black',
                                linewidth=0.4)

        # Track the extent of what is actually plotted.
        x_min, x_max, y_min, y_max = [], [], [], []
        if res_data['info'] is not None:
            info = np.asarray(res_data['info'])
            if info_plot is None:
                info_common = list(res_data['info_val'])
            else:
                info_common = _resid_common_info(info_plot[i_plot],
                                                 res_data['info_val'])
            n_info = len(info_common)
            if (isinstance(colour, list) and n_info > 1
                    and len(colour) != n_info):
                print('Error: Must use same number of colours as info '
                      'numbers.  Exiting')
                return None
            for i_info, info_val in enumerate(info_common):
                info_ind = np.where(info == info_val)
                res_x = x_all[info_ind]
                if np.size(res_x) == 0:
                    continue
                res_y = res_all[info_ind] + float(i_info) * resoffset
                res_err = reserr_all[info_ind]
                if binsize is not None:
                    res_x, res_y, res_err = bin_resids(res_x, res_y, res_err,
                                                       binsize=binsize,
                                                       weigh_x=True)
                if isinstance(colour, list):
                    # Lengths match here unless n_info == 1 (then black).
                    clr = colour[i_info] if len(colour) == n_info else 'black'
                elif colour is not None:
                    clr = colour
                elif n_info == 1:
                    clr = 'black'
                else:
                    clr = cm.jet(float(i_info) / float(n_info - 1))
                this_ax.plot(res_x, res_y, sym, markersize=symsize,
                             markeredgecolor=clr, markerfacecolor=clr)
                # fmt='none' draws only the error bars.
                this_ax.errorbar(res_x, res_y, yerr=res_err, fmt='none',
                                 capsize=csize, ecolor=clr)
                x_min.append(np.amin(res_x))
                x_max.append(np.amax(res_x))
                y_min.append(np.amin(res_y - res_err))
                y_max.append(np.amax(res_y + res_err))
        else:
            clr = 'black' if colour is None else colour
            res_x, res_y, res_err = x_all, res_all, reserr_all
            if binsize is not None:
                res_x, res_y, res_err = bin_resids(res_x, res_y, res_err,
                                                   binsize=binsize,
                                                   weigh_x=True)
            this_ax.plot(res_x, res_y, sym, markersize=symsize,
                         markeredgecolor=clr, markerfacecolor=clr)
            this_ax.errorbar(res_x, res_y, yerr=res_err, fmt='none',
                             capsize=csize, ecolor=clr)
            if np.size(res_x) > 0:
                x_min.append(np.amin(res_x))
                x_max.append(np.amax(res_x))
                y_min.append(np.amin(res_y - res_err))
                y_max.append(np.amax(res_y + res_err))

        # User limits are used exactly.  Data-derived limits get breathing
        # room: 2.5% in x and 5% in y on each side.
        if xlim is not None:
            this_ax.set_xlim(xlim)
        elif x_min:
            lo, hi = min(x_min), max(x_max)
            pad = 0.025 * (hi - lo)
            this_ax.set_xlim(lo - pad, hi + pad)
        if ylim is not None:
            this_ax.set_ylim(ylim)
        elif y_min:
            lo, hi = min(y_min), max(y_max)
            pad = 0.05 * (hi - lo)
            this_ax.set_ylim(lo - pad, hi + pad)

        # Figure text must be a list of tuples: [(x, y, text), ...]
        if figtext is not None:
            for txt in figtext:
                this_ax.text(txt[0], txt[1], txt[2], fontsize=10,
                             horizontalalignment='center',
                             verticalalignment='center')

    return ax


def _resid_match_units(resid_data, units, name):
    """Broadcast a plot_resid units keyword; return (resid_data, units) lists."""
    if not isinstance(units, list):
        # One unit for every data set.
        return resid_data, [units] * len(resid_data)
    if len(units) > 1 and len(resid_data) == 1:
        # Plot the single data set once per requested unit.
        return resid_data * len(units), units
    if len(units) == 1 and len(resid_data) > 1:
        return resid_data, units * len(resid_data)
    if len(units) != len(resid_data):
        print('plot_resid ERROR: ' + name + ' and resid_data must have same '
              'dimensions if they are both > 1 in length')
        exit()
    return resid_data, units


def _resid_per_plot(value, n_plots, name):
    """Return a per-panel list: a list is length-checked, a scalar replicated."""
    if isinstance(value, list):
        if len(value) != n_plots:
            print('plot_resid ERROR: ' + name +
                  ' list must have same dimensions as resid_data')
            exit()
        return value
    return [value] * n_plots


def _resid_info_plot(info_plot, n_plots):
    """Normalise info_plot to None or one sequence of info IDs per panel."""
    if info_plot is None:
        return None
    sequence_types = (list, tuple, np.ndarray)
    if not isinstance(info_plot, sequence_types):
        # A single info ID, used for every panel.
        return [[info_plot]] * n_plots
    if len(info_plot) > 0 and isinstance(info_plot[0], sequence_types):
        if len(info_plot) == 1:
            return [info_plot[0]] * n_plots
        if len(info_plot) != n_plots:
            print('plot_resid ERROR: info_plot keyword must have same '
                  'dimensions as resid_data.')
            exit()
        return list(info_plot)
    # A flat list of IDs, used for every panel.
    return [list(info_plot)] * n_plots


def _resid_axis_limits(axis_limits, n_plots):
    """Return one [left, bottom, width, height] axes rectangle per panel."""
    if n_plots == 0:
        return []
    if axis_limits is None:
        # Stack the panels evenly; the first data set is at the bottom.
        ywidth = 0.85 / n_plots
        return [[0.12, 0.12 + i_plot * ywidth, 0.8, ywidth]
                for i_plot in range(n_plots)]
    sequence_types = (list, tuple, np.ndarray)
    if (not isinstance(axis_limits, sequence_types)
            or (len(axis_limits) > 0
                and not isinstance(axis_limits[0], sequence_types))):
        # A single rectangle, used for every panel.
        return [axis_limits] * n_plots
    if len(axis_limits) == 1:
        return [axis_limits[0]] * n_plots
    if len(axis_limits) != n_plots:
        print('plot_resid ERROR: If specifying multiple axis_limits to '
              'keyword, it must be a list of same dimensions as resid_data')
        exit()
    return list(axis_limits)


def _resid_common_info(requested, available):
    """Return the requested info IDs present in *available*, in requested order."""
    available_set = set(np.ravel(available).tolist())
    common = []
    for val in np.ravel(requested).tolist():
        if val in available_set and val not in common:
            common.append(val)
    return common
Ejemplo n.º 4
0
def main():
    """Plot pulse width against year for each width file on the command line.

    Each command-line argument may be a glob pattern.  Every matching file is
    read with read_widths(), and its widths and errors are scaled by 360
    (phase -> degrees, as in plot_widths).  Each file is drawn in its own
    panel; the panels are stacked top-to-bottom in argument order.  All
    panels share the largest y span among the files, so that their scales
    are identical.  The figure is saved to 'widths_multi.png'.
    """
    # extend() adds every glob match individually.  append() would nest the
    # whole match list as a single element.
    width_file = []
    for file_name in argv[1:]:
        width_file.extend(glob.glob(file_name))

    n_subplots = len(width_file)
    print("N_SUBPLOTS = {}".format(n_subplots))
    if n_subplots == 0:
        print("No width files matched the command-line arguments. Exiting.")
        return

    # First pass: read and convert all files, so the common y span is known
    # before any panel's limits are set.
    panels = []
    max_yspan = 0.
    for path in width_file:
        width_data = read_widths(path)
        width_data['year'] = [mjd.mjdtoyear(m) for m in width_data['mjd']]
        width_data['width'] = np.asarray(width_data['width'], dtype=float) * 360.
        width_data['werr'] = np.asarray(width_data['werr'], dtype=float) * 360.
        ymin = np.amin(width_data['width'] - width_data['werr'])
        ymax = np.amax(width_data['width'] + width_data['werr'])
        max_yspan = max(max_yspan, abs(ymax - ymin))
        panels.append((width_data, ymin, ymax))

    fig = plt.figure()
    fig_top = 0.95
    fig_bottom = 0.13
    fig_left = 0.13
    fig_right = 0.95
    panel_height = (fig_top - fig_bottom) / float(n_subplots)

    for i_w, (width_data, ymin, ymax) in enumerate(panels):
        # The first file goes in the top panel.
        bottom = fig_bottom + panel_height * (n_subplots - (i_w + 1))
        ax = fig.add_axes([fig_left, bottom, fig_right - fig_left,
                           panel_height])
        ax.set_xlim(np.amin(width_data['year']) - 0.25,
                    np.amax(width_data['year']) + 0.25)
        # Centre each panel on its own data, using the common span.
        ycentre = 0.5 * (ymin + ymax)
        ax.set_ylim(ycentre - 0.5 * max_yspan, ycentre + 0.5 * max_yspan)
        if i_w < n_subplots - 1:
            # Only the bottom panel carries x tick labels.
            ax.tick_params(axis='x', labelbottom=False)
        else:
            ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        # Pruning both end ticks avoids overlapping labels between panels.
        ax.yaxis.set_major_locator(MaxNLocator(5, prune='both'))

        line, = ax.plot(width_data['year'], width_data['width'], 'o')
        # Match the error-bar colour to the points.  Otherwise errorbar takes
        # the next colour in the property cycle.
        ax.errorbar(width_data['year'], width_data['width'],
                    width_data['werr'], fmt='none', ecolor=line.get_color())

    # Shared axis labels for the whole figure.
    fig.text(0.5 * (fig_left + fig_right), 0.06, 'Year', fontsize=16,
             ha='center', va='center')
    fig.text(0.04, 0.5 * (fig_top + fig_bottom), 'Pulse width (deg)',
             fontsize=16, ha='center', va='center', rotation='vertical')

    plot_file = 'widths_multi.png'
    plt.savefig(plot_file)