def register(cls, function, column, corr, ms, minmax=None, ncol=None, subset=None, minmax_cache=None):
    """
    Registers a data axis, which ultimately ends up as a column in the assembled dataframe.
    For multiple plots, we want to reuse the same information (assuming the same
    clipping limits, etc.) so we have a dictionary of axis definitions here.

    Columns selects a column to operate on.
    Function selects a mapping (see datamappers below).
    Corr selects a correlation (or a Stokes product such as I, Q,...)
    minmax sets axis clipping levels
    ncol discretizes the axis into N colours between min and max
    minmax_cache provides a dict of cached min/max values, which will be looked up via the label, if minmax
    is not explicitly set

    Returns the existing DataAxis if one with the same (label, minmax, ncol) was registered before,
    otherwise a newly created DataAxis with a label that is unique within cls.all_labels.
    """
    # form up label
    label = "{}_{}_{}".format(col_to_label(column or ''), function, corr)
    minmax = tuple(minmax) if minmax is not None else (None, None)
    key = label, minmax, ncol

    # see if this axis definition already exists
    if key in cls.all_axes:
        return cls.all_axes[key]

    # no explicit limits given: see if min/max should be loaded from the cache
    if minmax == (None, None) and minmax_cache and label in minmax_cache:
        log.info(f"loading {label} min/max from cache")
        minmax = minmax_cache[label]

    # make the label unique: base, base_1, base_2, ... (always suffix the *base* label,
    # rather than accumulating suffixes like base_1_2)
    base_label, i = label, 0
    while label in cls.all_labels:
        i += 1
        label = f"{base_label}_{i}"
    cls.all_labels.add(label)

    axis = cls.all_axes[key] = DataAxis(column, function, corr, ms, minmax, ncol, label, subset=subset)
    return axis
def render_single_plot(df, subset, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel):
    """Render one plot and write it to ``pngname``.

    Kept as a standalone function since it might be called in parallel. Besides its arguments it
    reads ``options``, ``cmap``, ``bmap``, ``dmap``, ``extra_markup`` and ``minmax_cache`` from the
    enclosing scope. When ``options.profile`` is set, rendering runs under a dask ResourceProfiler
    and, if a plot was produced, the profile is saved next to the PNG as ``<stem>.prof.html``.
    """
    log.info(f": rendering {pngname}")

    how = options.norm
    if how == "auto":
        # colouring -> log scale; plain density -> histogram equalisation; alpha-weighted -> linear
        if cdatum is not None:
            how = "log"
        elif adatum is None:
            how = "eq_hist"
        else:
            how = 'linear'

    # nullcontext yields None, so ``profiler`` is None unless profiling was requested
    make_context = dask.diagnostics.ResourceProfiler if options.profile else nullcontext

    with make_context() as profiler:
        plot_kwargs = dict(
            cmap=cmap, bmap=bmap, dmap=dmap,
            normalize=how,
            min_alpha=options.min_alpha,
            saturate_alpha=options.saturate_alpha,
            saturate_percentile=options.saturate_perc,
            xlabel=xlabel, ylabel=ylabel, title=title,
            pngname=pngname,
            extra_markup=extra_markup,
            minmax_cache=minmax_cache,
            options=options,
        )
        result = data_plots.create_plot(df, subset, xdatum, ydatum, adatum, ared, cdatum, **plot_kwargs)

    if not result:
        return
    log.info(f' : wrote {pngname}')

    if profiler is None:
        return
    profile_file = os.path.splitext(pngname)[0] + ".prof.html"
    dask.diagnostics.visualize(profiler, file_path=profile_file, show=False, save=True)
    log.info(f' : wrote profiler info to {profile_file}')
def get_colormap(cmap_name):
    """Look up a colour map by name.

    Tries colorcet first, then cmasher, then matplotlib.cm.

    Returns:
        For a colorcet map, the colorcet object itself (returned as-is). For cmasher/matplotlib
        maps, a list of "#rrggbb" hex strings, one per entry of the map's lookup table.

    Raises:
        ValueError: if no colour map of that name is found.
    """
    cmap = getattr(colorcet, cmap_name, None)
    if cmap:
        log.info(f"using colourmap colorcet.{cmap_name}")
        return cmap

    cmap = getattr(cmasher, cmap_name, None)
    if cmap:
        log.info(f"using colourmap cmasher.{cmap_name}")
    else:
        cmap = getattr(matplotlib.cm, cmap_name, None)
        if cmap is None:
            raise ValueError(f"unknown colourmap {cmap_name}")
        log.info(f"using colourmap matplotlib.cm.{cmap_name}")

    # Sample every lookup-table entry by integer index (integers index the LUT directly). Unlike
    # reading ``cmap.colors``, this also works for LinearSegmentedColormap (which has no
    # ``.colors``) and always yields RGBA float tuples, however the colours were specified.
    rgba = [cmap(i) for i in range(cmap.N)]
    return [f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}" for r, g, b, _alpha in rgba]
def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, chanslice, subset, noflags, noconj,
                  iter_field, iter_spw, iter_scan, join_corrs=False, row_chunk_size=100000):
    """Read the MS with dask-ms and build dask dataframes holding every registered axis.

    Every axis in ``DataAxis.all_axes`` becomes a column (named by its label) of the output frames.

    Args:
        msinfo: MS description; ``msinfo.msname`` and ``msinfo.antenna`` are used here.
        group_cols: dask-ms grouping columns. FIELD_ID and DATA_DESC_ID are read from every group;
            SCAN_NUMBER is used if present.
        mytaql: TaQL WHERE clause passed to dask-ms.
        chan_freqs: mapping from DATA_DESC_ID to that SPW's channel frequencies.
        chanslice: slice applied to the channel axis.
        subset: selection object; ``subset.field``, ``subset.spw`` and ``subset.corr.numbers`` are used.
        noflags: if true, FLAG/FLAG_ROW are not read and no flags are passed to the axes.
        noconj: if true, no conjugate points are added for axes marked ``conjugate``.
        iter_field, iter_spw, iter_scan: if true, the corresponding value becomes part of the output
            key; otherwise that key element is None and groups are merged into one frame.
        join_corrs: if true, datums for all selected correlations are concatenated into the frame;
            otherwise each datum is computed only once (for the first correlation in the subset).
        row_chunk_size: dask-ms row chunk size.

    Returns:
        ``(output_dataframes, num_points)``: an OrderedDict keyed by ``(field, spw, scan, antenna)``
        (elements are None when not iterated over; antenna is always None for now), and the total
        number of points across all frames.
    """
    ms_cols = {'ANTENNA1', 'ANTENNA2'}
    if not noflags:
        ms_cols.update({'FLAG', 'FLAG_ROW'})

    # get visibility columns needed by the registered axes
    for axis in DataAxis.all_axes.values():
        ms_cols.update(axis.columns)

    # get MS data
    msdata = daskms.xds_from_ms(msinfo.msname, columns=list(ms_cols), group_cols=group_cols,
                                taql_where=mytaql, chunks=dict(row=row_chunk_size))

    log.info(f': Indexing MS and building dataframes (chunk size is {row_chunk_size})')

    num_points = 0  # number of points to plot

    # output dataframes, indexed by (field, spw, scan, antenna).
    # If any of these axes is not being iterated over, then that element of the index is None
    output_dataframes = OrderedDict()

    # iterate over groups
    for group in msdata:
        ddid = group.DATA_DESC_ID  # always present
        fld = group.FIELD_ID  # always present
        if fld not in subset.field or ddid not in subset.spw:
            log.debug(f"field {fld} ddid {ddid} not in selection, skipping")
            continue
        scan = getattr(group, 'SCAN_NUMBER', None)  # will be present if iterating over scans
        # TODO: antenna iteration. None forces no iteration, for now
        antenna = None

        # flags are only read when flagging is enabled
        flag = group.FLAG if not noflags else None
        flag_row = group.FLAG_ROW if not noflags else None

        baselines = group.ANTENNA1*len(msinfo.antenna) + group.ANTENNA2
        freqs = chan_freqs[ddid]
        chans = xarray.DataArray(range(len(freqs)), dims=("chan",))
        wavel = freq_to_wavel(freqs)
        extras = dict(chans=chans, freqs=freqs, wavel=wavel, rows=group.row, baselines=baselines)

        # expected (row, channel) shape of every datum, after channel slicing
        nchan = len(group.chan)
        if flag is not None:
            flag = flag[dict(chan=chanslice)]
            nchan = flag.shape[1]
        else:
            # no flag array to slice (noflags): apply the channel slice to the channel count
            # directly, so the expected shape agrees with datums computed using chanslice
            nchan = len(range(nchan)[chanslice])
        shape = (len(group.row), nchan)

        datums = OrderedDict()

        for corr in subset.corr.numbers:
            # make dictionary of extra values for DataMappers
            extras['corr'] = corr
            # loop over datums to be computed
            for axis in DataAxis.all_axes.values():
                value = datums[axis.label][-1] if axis.label in datums else None
                # a datum was already computed?
                if value is not None:
                    # if not joining correlations, then that's the only one we'll need, so continue
                    if not join_corrs:
                        continue
                    # joining correlations, and datum has a correlation dependence: compute another one
                    if axis.corr is None:
                        value = None
                if value is None:
                    value = axis.get_value(group, corr, extras, flag=flag, flag_row=flag_row,
                                           chanslice=chanslice)
                    # reshape values of shape NTIME to (NTIME,1) and NFREQ to (1,NFREQ), and scalar to (NTIME,1)
                    if value.ndim == 1:
                        timefreq_axis = axis.mapper.axis or 0
                        assert value.shape[0] == shape[timefreq_axis], \
                            f"{axis.mapper.fullname}: size {value.shape[0]}, expected {shape[timefreq_axis]}"
                        shape1 = [1, 1]
                        shape1[timefreq_axis] = value.shape[0]
                        value = value.reshape(shape1)
                        if timefreq_axis > 0:
                            value = da.broadcast_to(value, shape)
                        log.debug(f"axis {axis.mapper.fullname} has shape {value.shape}")
                    # else 2D value better match expected shape
                    else:
                        assert value.shape == shape, f"{axis.mapper.fullname}: shape {value.shape}, expected {shape}"
                datums.setdefault(axis.label, []).append(value)

        # if joining correlations, stick all elements together. Otherwise, we'd better have one per label
        if join_corrs:
            datums = OrderedDict({label: da.concatenate(arrs) for label, arrs in datums.items()})
        else:
            assert all(len(arrs) == 1 for arrs in datums.values())
            datums = OrderedDict({label: arrs[0] for label, arrs in datums.items()})

        # broadcast to same shape, and unravel all datums
        datums = OrderedDict({key: arr.ravel()
                              for key, arr in zip(datums.keys(), da.broadcast_arrays(*datums.values()))})

        # if any axis needs to be conjugated, double up all of them (so columns stay aligned)
        if not noconj and any(axis.conjugate for axis in DataAxis.all_axes.values()):
            for axis in DataAxis.all_axes.values():
                if axis.conjugate:
                    datums[axis.label] = da.concatenate([datums[axis.label], -datums[axis.label]])
                else:
                    datums[axis.label] = da.concatenate([datums[axis.label], datums[axis.label]])

        labels, values = list(datums.keys()), list(datums.values())
        num_points += values[0].size

        # now stack them all into a big dataframe
        rectype = [(axis.label, numpy.int32 if axis.nlevels else numpy.float32)
                   for axis in DataAxis.all_axes.values()]
        recarr = da.empty_like(values[0], dtype=rectype)
        ddf = dask_df.from_array(recarr)
        for label, value in zip(labels, values):
            ddf[label] = value

        # now, are we iterating or concatenating? Make frame key accordingly
        dataframe_key = (fld if iter_field else None,
                         ddid if iter_spw else None,
                         scan if iter_scan else None,
                         antenna)

        # do we already have a frame for this key
        ddf0 = output_dataframes.get(dataframe_key)
        if ddf0 is None:
            log.debug(f"first frame for {dataframe_key}")
            output_dataframes[dataframe_key] = ddf
        else:
            log.debug(f"appending to frame for {dataframe_key}")
            # DataFrame.append is deprecated/removed in dask; concat is the equivalent
            output_dataframes[dataframe_key] = dask_df.concat([ddf0, ddf])

    # convert discrete axes into categoricals
    if data_mappers.USE_COUNT_CAT:
        categorical_axes = [axis.label for axis in DataAxis.all_axes.values() if axis.nlevels]
        if categorical_axes:
            log.info(": counting colours")
            for key, ddf in list(output_dataframes.items()):
                output_dataframes[key] = ddf.categorize(categorical_axes)

    log.info(": complete")
    return output_dataframes, num_points
def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, normalize,
                xlabel, ylabel, title, pngname, options=None):
    """Rasterize a dataframe with datashader and render the result to a PNG with matplotlib.

    Args:
        ddf: dataframe with one column per axis, looked up by each datum's ``label``.
        xdatum, ydatum: X and Y axes; their ``minmax`` (if set) fixes the canvas/plot range.
        adatum: optional axis whose values, reduced by ``ared``, set the intensity (alpha).
        ared: name of a ``datashader.reductions`` function used for the alpha axis.
        cdatum: optional axis to colour by (discrete, or continuous discretized into bins).
        cmap: colour map used when not colouring by an axis.
        bmap: list of colours used for a discretized continuous colour axis.
        dmap: list of colours used for a discrete colour axis (indexed by bin number).
        normalize: datashader ``how`` normalization passed to ``shade``.
        xlabel, ylabel, title: plot text.
        pngname: output filename.
        options: parsed options. Required despite the None default: canvas size, background colour,
            spreading and font size are read from it.

    Returns:
        ``pngname`` on success, or None if there was no valid data to plot.

    Raises:
        ValueError: if ``ared`` is not a known datashader reduction.
    """
    figx = options.xcanvas / 60
    figy = options.ycanvas / 60
    bgcol = "#" + options.bgcol.lstrip("#")

    # options.fontsize may be a string (e.g. "14" from a command-line option without a type).
    # Numeric sizes are converted so the colorbar label size can be scaled below; named sizes
    # (e.g. 'large') are passed to matplotlib unchanged.
    fontsize = options.fontsize
    try:
        fontsize = float(fontsize)
        numeric_fontsize = True
    except (TypeError, ValueError):
        numeric_fontsize = False

    xaxis = xdatum.label
    yaxis = ydatum.label
    aaxis = adatum and adatum.label
    caxis = cdatum and cdatum.label

    color_key = color_mapping = color_labels = agg_alpha = raster_alpha = None

    xmin, xmax = xdatum.minmax
    ymin, ymax = ydatum.minmax

    canvas = datashader.Canvas(options.xcanvas, options.ycanvas,
                               x_range=[xmin, xmax] if xmin is not None else None,
                               y_range=[ymin, ymax] if ymin is not None else None)

    if aaxis is not None:
        agg_alpha = getattr(datashader.reductions, ared, None)
        if agg_alpha is None:
            raise ValueError(f"unknown alpha reduction function {ared}")
        agg_alpha = agg_alpha(aaxis)
    ared = ared or 'count'

    if cdatum is not None:
        # without reduce-by, the alpha channel is rasterized separately and applied afterwards
        if agg_alpha is not None and not USE_REDUCE_BY:
            log.debug(f'rasterizing alpha channel using {ared}(aaxis)')
            raster_alpha = canvas.points(ddf, xaxis, yaxis, agg=agg_alpha)

        if data_mappers.USE_COUNT_CAT:
            color_bins = [int(x) for x in getattr(ddf.dtypes, caxis).categories]
            log.debug(f'colourizing with count_cat, {len(color_bins)} bins')
            if USE_REDUCE_BY and agg_alpha is not None:
                agg = datashader.by(caxis, agg_alpha)
            else:
                agg = datashader.count_cat(caxis)
        else:
            color_bins = list(range(cdatum.nlevels))
            log.debug(f'colourizing with count_integer, {len(color_bins)} bins')
            if USE_REDUCE_BY and agg_alpha is not None:
                agg = by_integers(caxis, agg_alpha, cdatum.nlevels)
            else:
                agg = count_integers(caxis, cdatum.nlevels)

        raster = canvas.points(ddf, xaxis, yaxis, agg=agg)
        # per-bin flag: does this colour bin have any data anywhere on the canvas?
        non_empty = numpy.array(raster.any(axis=(0, 1)))
        if not non_empty.any():
            log.info(": no valid data in plot. Check your flags and/or plot limits.")
            return None

        # true if axis is continuous discretized
        if cdatum.discretized_delta is not None:
            # color labels are bin centres
            bin_centers = [cdatum.discretized_bin_centers[i] for i in color_bins]
            # spread the bins evenly across the whole continuous colour map
            color_key = [bmap[(i * len(bmap)) // cdatum.nlevels] for i in color_bins]
            color_labels = list(map(str, bin_centers))
            log.info(f": shading using {len(color_bins)} colors (bin centres are {' '.join(color_labels)})")
        # else a discrete axis
        else:
            # discard empty bins
            non_empty = numpy.where(non_empty)[0]
            raster = raster[..., non_empty]
            # just use bin numbers to look up a color directly, wrapping around if there are
            # more bins than colours in the map
            color_bins = [color_bins[i] for i in non_empty]
            color_key = [dmap[bin_num % len(dmap)] for bin_num in color_bins]
            # the numbers may be out of order -- reorder for color bar purposes
            bin_color = sorted(zip(color_bins, color_key))
            if cdatum.discretized_labels and len(cdatum.discretized_labels) <= cdatum.nlevels:
                color_labels = [cdatum.discretized_labels[bin_num] for bin_num, _ in bin_color]
            else:
                color_labels = [str(bin_num) for bin_num, _ in bin_color]
            color_mapping = [col for _, col in bin_color]
            log.info(f": rendering using {len(color_bins)} colors (values {' '.join(color_labels)})")

        if raster_alpha is not None:
            amin, amax = numpy.nanmin(raster_alpha), numpy.nanmax(raster_alpha)
            # scale by alpha normalized to [0, 1]; a constant (or all-NaN) alpha raster would
            # divide by zero and blank the whole image, so it is left unscaled instead
            if amax > amin:
                raster = raster*(raster_alpha-amin)/(amax-amin)
                log.info(f": adjusting alpha (alpha raster was {amin} to {amax})")
            else:
                log.info(f": alpha raster is constant ({amin}), not adjusting alpha")

        img = datashader.transfer_functions.shade(raster, color_key=color_key, how=normalize)
    else:
        log.debug(f'rasterizing using {ared}')
        raster = canvas.points(ddf, xaxis, yaxis, agg=agg_alpha)
        if not raster.data.any():
            log.info(": no valid data in plot. Check your flags and/or plot limits.")
            return None
        log.debug('shading')
        img = datashader.transfer_functions.shade(raster, cmap=cmap, how=normalize)

    if options.spread_pix:
        img = datashader.transfer_functions.dynspread(img, options.spread_thr, max_px=options.spread_pix)
        log.info(f": spreading ({options.spread_thr} {options.spread_pix})")
    rgb = holoviews.RGB(holoviews.operation.datashader.shade.uint32_to_uint8_xr(img))

    log.debug('done')

    # Set plot limits based on data extent or user values for axis labels
    data_xmin = numpy.min(raster.coords[xaxis].values)
    data_xmax = numpy.max(raster.coords[xaxis].values)
    data_ymin = numpy.min(raster.coords[yaxis].values)
    data_ymax = numpy.max(raster.coords[yaxis].values)

    if xmin is None:
        xmin = data_xmin
    if xmax is None:
        xmax = data_xmax
    if ymin is None:
        ymin = data_ymin
    if ymax is None:
        ymax = data_ymax

    log.debug('rendering image')

    def is_text(artist):
        return artist.__module__ == 'matplotlib.text'

    fig = pylab.figure(figsize=(figx, figy))
    # always close this specific figure, even if rendering or saving fails, so figures do not leak
    try:
        ax = fig.add_subplot(111, facecolor=bgcol)
        ax.imshow(X=rgb.data, extent=[data_xmin, data_xmax, data_ymin, data_ymax],
                  aspect='auto', origin='lower')
        ax.set_title("\n".join(textwrap.wrap(title, 90)), loc='left')
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        # pad the limits by 1% on either side
        dx, dy = xmax - xmin, ymax - ymin
        ax.set_xlim([xmin - dx/100, xmax + dx/100])
        ax.set_ylim([ymin - dy/100, ymax + dy/100])

        # set fontsize on everything rendered so far
        for textobj in fig.findobj(match=is_text):
            textobj.set_fontsize(fontsize)

        # colorbar?
        if color_key:
            import matplotlib.colors
            # discrete axis
            if color_mapping is not None:
                norm = matplotlib.colors.Normalize(-0.5, len(color_bins)-0.5)
                ticks = numpy.arange(len(color_bins))
                colormap = matplotlib.colors.ListedColormap(color_mapping)
            # discretized axis
            else:
                norm = matplotlib.colors.Normalize(cdatum.minmax[0], cdatum.minmax[1])
                colormap = matplotlib.colors.ListedColormap(color_key)
                # auto-mark colorbar, since it represents a continuous range of values
                ticks = None

            cb = fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=colormap), ax=ax, ticks=ticks)

            # adjust ticks for discrete axis
            if color_mapping is not None:
                rot = 0
                # adjust fontsize for number of labels (named sizes cannot be scaled)
                if numeric_fontsize:
                    fs = max(fontsize*min(1, 32./len(color_labels)), 6)
                else:
                    fs = fontsize
                fontdict = dict(fontsize=fs)
                if max(len(lbl) for lbl in color_labels) > 3 and len(color_labels) < 8:
                    rot = 90
                    fontdict['verticalalignment'] = 'center'
                cb.ax.set_yticklabels(color_labels, rotation=rot, fontdict=fontdict)

        fig.savefig(pngname, bbox_inches='tight')
    finally:
        pylab.close(fig)

    return pngname
def main(argv): clock_start = time.time() # default # of CPUs # --------------------------------------------------------------------------------------------------------------------------------------------- parser = argparse.ArgumentParser(description='Rapid Measurement Set plotting with dask-ms and datashader. Version {0:s}'.format(__version__)) parser.add_argument('ms', help='Measurement set') parser.add_argument("-v", "--version", action='version', version='{:s} version {:s}'.format(parser.prog, __version__)) group_opts = parser.add_argument_group('Plot types and data sources') group_opts.add_argument('-x', '--xaxis', dest='xaxis', action="append", help="""X axis of plot, e.g. "amp:CORRECTED_DATA" This recognizes all column names (also CHAN, FREQ, CORR, ROW, WAVEL, U, V, W, UV), and, for complex columns, keywords such as 'amp', 'phase', 'real', 'imag'. You can also specify correlations, e.g. 'DATA:phase:XX', and do two-column arithmetic with "+-*/", e.g. 'DATA-MODEL_DATA:amp'. Correlations may be specified by label, number, or as a Stokes parameter. The order of specifiers does not matter. """) group_opts.add_argument('-y', '--yaxis', dest='yaxis', action="append", help="""Y axis to plot. Must be given the same number of times as --xaxis. Note that X/Y can employ different columns and correlations.""") group_opts.add_argument('-a', '--aaxis', action="append", help="""Intensity axis. Can be none, or given once, or given the same number of times as --xaxis. If none, plot intensity (a.k.a. alpha channel) is proportional to density of points. Otherwise, a reduction function (see --ared below) is applied to the given values, and the result is used to determine intensity. """) group_opts.add_argument('--ared', action="append", help="""Alpha axis reduction function. Recognized reductions are count, any, sum, min, max, mean, std, first, last, mode. Default is mean.""") group_opts.add_argument('-c', '--colour-by', action="append", help="""Colour axis. 
Can be none, or given once, or given the same number of times as --xaxis. All columns and variations listed under --xaxis are available for colouring by.""") group_opts.add_argument('-C', '--col', metavar="COLUMN", dest='col', action="append", default=[], help="""Name of visibility column (default is DATA), if needed. This is used if the axis specifications do not explicitly include a column. For multiple plots, this can be given multiple times, or as a comma-separated list. Two-column arithmetic is recognized. """) group_opts.add_argument('--noflags', help='Enable to ignore flags. Default is to omit flagged data.', action='store_true') group_opts.add_argument('--noconj', help='Do not show conjugate points in u,v plots (default = plot conjugates).', action='store_true') group_opts = parser.add_argument_group('Plot axes setup') group_opts.add_argument('--xmin', action='append', help="""Minimum x-axis value (default = data min). For multiple plots, you can give this multiple times, or use a comma-separated list, but note that the clipping is the same per axis across all plots, so only the last applicable setting will be used. The list may include empty elements (or 'None') to not apply a clip.""") group_opts.add_argument('--xmax', action='append', help='Maximum x-axis value (default = data max).') group_opts.add_argument('--ymin', action='append', help='Minimum y-axis value (default = data min).') group_opts.add_argument('--ymax', action='append', help='Maximum y-axis value (default = data max).') group_opts.add_argument('--cmin', action='append', help='Minimum colouring value. Must be supplied for every non-discrete axis to be coloured by.') group_opts.add_argument('--cmax', action='append', help='Maximum colouring value. Must be supplied for every non-discrete axis to be coloured by.') group_opts.add_argument('--cnum', action='append', help=f'Number of steps used to discretize a continuous axis. 
Default is {DEFAULT_CNUM}.') group_opts = parser.add_argument_group('Options for multiple plots or combined plots') group_opts.add_argument('--iter-field', action="store_true", help='Separate plots per field (default is to combine in one plot)') group_opts.add_argument('--iter-antenna', action="store_true", help='Separate plots per antenna (default is to combine in one plot)') group_opts.add_argument('--iter-spw', action="store_true", help='Separate plots per spw (default is to combine in one plot)') group_opts.add_argument('--iter-scan', action="store_true", help='Separate plots per scan (default is to combine in one plot)') group_opts.add_argument('--iter-corr', action="store_true", help='Separate plots per correlation or Stokes (default is to combine in one plot)') group_opts = parser.add_argument_group('Data subset selection') group_opts.add_argument('--ant', default='all', help='Antennas to plot (comma-separated list of names, default = all)') group_opts.add_argument('--ant-num', help='Antennas to plot (comma-separated list of numbers, or a [start]:[stop][:step] slice, overrides --ant)') group_opts.add_argument('--baseline', default='all', help="Baselines to plot, as 'ant1-ant2' (comma-separated list, default = all)") group_opts.add_argument('--spw', default='all', help='Spectral windows (DDIDs) to plot (comma-separated list, default = all)') group_opts.add_argument('--field', default='all', help='Field ID(s) to plot (comma-separated list, default = all)') group_opts.add_argument('--scan', default='all', help='Scans to plot (comma-separated list, default = all)') group_opts.add_argument('--corr', default='all', help='Correlations or Stokes to plot, use indices or labels (comma-separated list, default = all)') group_opts.add_argument('--chan', help='Channel slice, as [start]:[stop][:step], default is to plot all channels') group_opts = parser.add_argument_group('Rendering settings') group_opts.add_argument('-X', '--xcanvas', type=int, help='Canvas x-size in 
pixels (default = %(default)s)', default=1280) group_opts.add_argument('-Y', '--ycanvas', type=int, help='Canvas y-size in pixels (default = %(default)s)', default=900) group_opts.add_argument('--norm', choices=['auto', 'eq_hist', 'cbrt', 'log', 'linear'], default='auto', help="Pixel scale normalization (default is 'log' when colouring, and 'eq_hist' when not)") group_opts.add_argument('--cmap', default='bkr', help="""Colorcet map used without --colour-by (default = %(default)s), see https://colorcet.holoviz.org""") group_opts.add_argument('--bmap', default='bkr', help='Colorcet map used when colouring by a continuous axis (default = %(default)s)') group_opts.add_argument('--dmap', default='glasbey_dark', help='Colorcet map used when colouring by a discrete axis (default = %(default)s)') group_opts.add_argument('--spread-pix', type=int, default=0, metavar="PIX", help="""Dynamically spread rendered pixels to this size""") group_opts.add_argument('--spread-thr', type=float, default=0.5, metavar="THR", help="""Threshold parameter for spreading (0 to 1, default %(default)s)""") group_opts.add_argument('--bgcol', dest='bgcol', help='RGB hex code for background colour (default = FFFFFF)', default='FFFFFF') group_opts.add_argument('--fontsize', dest='fontsize', help='Font size for all text elements (default = 20)', default=20) group_opts = parser.add_argument_group('Output settings') # can also use "plot-{msbase}-{column}-{corr}-{xfullname}-vs-{yfullname}", let's expand on this later group_opts.add_argument('--dir', help='Send all plots to this output directory') group_opts.add_argument('-s', '--suffix', help="suffix to be included in filenames, can include {options}") group_opts.add_argument('--png', dest='pngname', default="plot-{ms}{_field}{_Spw}{_Scan}{_Ant}-{label}{_alphalabel}{_colorlabel}{_suffix}.png", help='Template for output png files, default "%(default)s"') group_opts.add_argument('--title', 
default="{ms}{_field}{_Spw}{_Scan}{_Ant}{_title}{_Alphatitle}{_Colortitle}", help='Template for plot titles, default "%(default)s"') group_opts.add_argument('--xlabel', default="{xname}{_xunit}", help='Template for X axis labels, default "%(default)s"') group_opts.add_argument('--ylabel', default="{yname}{_yunit}", help='Template for X axis labels, default "%(default)s"') group_opts = parser.add_argument_group('Performance & tweaking') group_opts.add_argument("-d", "--debug", action='store_true', help="Enable debugging output") group_opts.add_argument('-z', '--row-chunk-size', type=int, metavar="NROWS", default=100000, help="""Row chunk size for dask-ms. Larger chunks may or may not be faster, but will certainly use more RAM.""") group_opts.add_argument('-j', '--num-parallel', type=int, metavar="N", default=1, help=f"""Run up to N renderers in parallel. Default is serial. Use -j0 to auto-set this to half the available cores ({DEFAULT_NUM_RENDERS} on this system). This is not necessarily faster, as they might all end up contending for disk I/O. This might also work against dask-ms's own intrinsic parallelism. 
You have been advised.""") group_opts.add_argument("--profile", action='store_true', help="Enable dask profiling output") # various hidden performance-testing options data_mappers.add_options(group_opts) data_plots.add_options(group_opts) options = parser.parse_args(argv) cmap = getattr(colorcet, options.cmap, None) if cmap is None: parser.error(f"unknown --cmap {options.cmap}") bmap = getattr(colorcet, options.bmap, None) if bmap is None: parser.error(f"unknown --bmap {options.bmap}") dmap = getattr(colorcet, options.dmap, None) if dmap is None: parser.error(f"unknown --dmap {options.dmap}") options.ms = options.ms.rstrip('/') if options.debug: shade_ms.log_console_handler.setLevel(logging.DEBUG) # pass options to shade_ms data_mappers.set_options(options) data_plots.set_options(options) # figure our list of plots to make if not options.xaxis: xaxes = ['TIME'] # Default xaxis if none is specified else: xaxes = list(itertools.chain(*[opt.split(",") for opt in options.xaxis])) if not options.yaxis: yaxes = ['DATA:amp'] # Default yaxis if none is specified else: yaxes = list(itertools.chain(*[opt.split(",") for opt in options.yaxis])) if len(xaxes) != len(yaxes): parser.error("--xaxis and --yaxis must be given the same number of times") def get_conformal_list(name, force_type=None, default=None): """ For all other settings, returns list same length as xaxes, or throws error if no conformance. 
Can also impose a type such as float (returning None for an empty string) """ optlist = getattr(options, name, None) if not optlist: return [default]*len(xaxes) # stick all lists together elems = list(itertools.chain(*[opt.split(",") for opt in optlist])) if len(elems) > 1 and len(elems) != len(xaxes): parser.error(f"--{name} must be given the same number of times as --xaxis, or else just once") # convert type if force_type: elems = [force_type(x) if x and x.lower() != "none" else None for x in elems] if len(elems) != len(xaxes): elems = [elems[0]]*len(xaxes) return elems # get list of columns and plot limites of the same length if not options.col: options.col = ["DATA"] columns = get_conformal_list('col') xmins = get_conformal_list('xmin', float) xmaxs = get_conformal_list('xmax', float) ymins = get_conformal_list('ymin', float) ymaxs = get_conformal_list('ymax', float) aaxes = get_conformal_list('aaxis') areds = get_conformal_list('ared', str, 'mean') caxes = get_conformal_list('colour_by') cmins = get_conformal_list('cmin', float) cmaxs = get_conformal_list('cmax', float) cnums = get_conformal_list('cnum', int, default=DEFAULT_CNUM) # check min/max if any([(a is None)^(b is None) for a, b in zip(xmins, xmaxs)]): parser.error("--xmin/--xmax must be either both set, or neither") if any([(a is None)^(b is None) for a, b in zip(ymins, ymaxs)]): parser.error("--xmin/--xmax must be either both set, or neither") if any([(a is None)^(b is None) for a, b in zip(ymins, ymaxs)]): parser.error("--cmin/--cmax must be either both set, or neither") # check chan slice def parse_slice_spec(spec, name): if spec: try: spec_elems = [int(x) if x else None for x in spec.split(":", 2)] except ValueError: parser.error(f"invalid selection --{name} {spec}") return slice(*spec_elems), spec_elems else: return slice(None), [] chanslice, chanslice_spec = parse_slice_spec(options.chan, name="chan") log.info(" ".join(sys.argv)) blank() ms = MSInfo(options.ms, log=log) blank() log.info(": Data 
selected for plotting:") group_cols = ['FIELD_ID', 'DATA_DESC_ID'] mytaql = [] class Subset(object): pass subset = Subset() if options.ant != 'all' or options.ant_num: if options.ant_num: ant_subset = set() for spec in options.ant_num.split(","): if re.fullmatch(r"\d+", spec): ant_subset.add(int(spec)) else: ant_subset.update(ms.all_antenna.numbers[parse_slice_spec(spec, "ant-num")[0]]) subset.ant = ms.antenna.get_subset(sorted(ant_subset)) else: subset.ant = ms.antenna.get_subset(options.ant) log.info(f"Antenna name(s) : {' '.join(subset.ant.names)}") mytaql.append("||".join([f'ANTENNA1=={ant}||ANTENNA2=={ant}' for ant in subset.ant.numbers])) else: subset.ant = ms.antenna log.info('Antenna(s) : all') if options.iter_antenna: raise NotImplementedError("iteration over antennas not currently supported") if options.baseline != 'all': subset.baseline = OrderedDict() for blspec in options.baseline.split(","): match = re.fullmatch(r"(\w+)-(\w+)", blspec) ant1 = match and ms.antenna[match.group(1)] ant2 = match and ms.antenna[match.group(2)] if ant1 is None or ant2 is None: raise ValueError("invalid baseline '{blspec}'") subset.baseline[blspec] = (ant1, ant2) # group_cols.append('ANTENNA1') log.info(f"Baseline(s) : {' '.join(subset.baseline.keys())}") mytaql.append("||".join([f'(ANTENNA1=={ant1}&&ANTENNA2=={ant2})||(ANTENNA1=={ant2}&&ANTENNA2=={ant1})' for ant1, ant2 in subset.baseline.values()])) else: log.info('Baseline(s) : all') if options.field != 'all': subset.field = ms.field.get_subset(options.field) log.info(f"Field(s) : {' '.join(subset.field.names)}") # here for now, workaround for https://github.com/ska-sa/dask-ms/issues/100, should be inside if clause mytaql.append("||".join([f'FIELD_ID=={fld}' for fld in subset.field.numbers])) else: subset.field = ms.field log.info('Field(s) : all') if options.spw != 'all': subset.spw = ms.spw.get_subset(options.spw) log.info(f"SPW(s) : {' '.join(subset.spw.names)}") mytaql.append("||".join([f'DATA_DESC_ID=={ddid}' for 
ddid in subset.spw.numbers])) else: subset.spw = ms.spw log.info(f'SPW(s) : all') if options.scan != 'all': subset.scan = ms.scan.get_subset(options.scan) log.info(f"Scan(s) : {' '.join(subset.scan.names)}") mytaql.append("||".join([f'SCAN_NUMBER=={n}' for n in subset.scan.numbers])) else: subset.scan = ms.scan log.info('Scan(s) : all') if options.iter_scan: group_cols.append('SCAN_NUMBER') if chanslice == slice(None): log.info('Channels : all') else: log.info(f"Channels : {':'.join(str(x) if x is not None else '' for x in chanslice_spec)}") mytaql = ' && '.join([f"({t})" for t in mytaql]) if mytaql else '' options.corr = options.corr.upper() if options.corr == "ALL": subset.corr = ms.corr else: # recognize shortcut when it's just Stokes indices, convert to list if re.fullmatch(r"[IQUV]+", options.corr): options.corr = ",".join(options.corr) subset.corr = ms.all_corr.get_subset(options.corr) log.info(f"Corr/Stokes : {' '.join(subset.corr.names)}") blank() # figure out list of plots to make all_plots = [] # This will be True if any of the specified axes change with correlation have_corr_dependence = False # now go create definitions for xaxis, yaxis, default_column, caxis, aaxis, ared, xmin, xmax, ymin, ymax, cmin, cmax, cnum in \ zip(xaxes, yaxes, columns, caxes, aaxes, areds, xmins, xmaxs, ymins, ymaxs, cmins, cmaxs, cnums): # get axis specs xspecs = DataAxis.parse_datum_spec(xaxis, default_column, ms=ms) yspecs = DataAxis.parse_datum_spec(yaxis, default_column, ms=ms) aspecs = DataAxis.parse_datum_spec(aaxis, default_column, ms=ms) if aaxis else [None] * 4 cspecs = DataAxis.parse_datum_spec(caxis, default_column, ms=ms) if caxis else [None] * 4 # parse axis specifications xfunction, xcolumn, xcorr, xitercorr = xspecs yfunction, ycolumn, ycorr, yitercorr = yspecs afunction, acolumn, acorr, aitercorr = aspecs cfunction, ccolumn, ccorr, citercorr = cspecs # does anything here depend on correlation? 
datum_itercorr = (xitercorr or yitercorr or aitercorr or citercorr) if datum_itercorr: have_corr_dependence = True if "FLAG" in (xcolumn, ycolumn, acolumn, ccolumn) or "FLAG_ROW" in (xcolumn, ycolumn, acolumn, ccolumn): if not options.noflags: log.info(": plotting a flag column implies that flagged data will not be masked") options.noflags = True # do we iterate over correlations/Stokes to make separate plots now? if datum_itercorr and options.iter_corr: corr_list = subset.corr.numbers else: corr_list = [None] def describe_corr(corrvalue): """Returns list of correlation labels corresponding to this corr setting""" if corrvalue is None: return subset.corr.names elif corrvalue is False: return [] else: return [ms.all_corr.names[corrvalue]] for corr in corr_list: plot_xcorr = corr if xcorr is None else xcorr # False if no corr in datum, None if all, else set to iterant or to fixed value plot_ycorr = corr if ycorr is None else ycorr plot_acorr = corr if acorr is None else acorr plot_ccorr = corr if ccorr is None else ccorr xdatum = DataAxis.register(xfunction, xcolumn, plot_xcorr, ms=ms, minmax=(xmin, xmax), subset=subset) ydatum = DataAxis.register(yfunction, ycolumn, plot_ycorr, ms=ms, minmax=(ymin, ymax), subset=subset) adatum = afunction and DataAxis.register(afunction, acolumn, plot_acorr, ms=ms, subset=subset) cdatum = cfunction and DataAxis.register(cfunction, ccolumn, plot_ccorr, ms=ms, minmax=(cmin, cmax), ncol=cnum, subset=subset) # figure out plot properties -- basically construct a descriptive name and label # looks complicated, but we're just trying to figure out what to put in the plot title... 
props = dict() titles = [] labels = [] # start with column and correlation(s) if ycolumn and not ydatum.mapper.column: # only put column if not fixed by mapper titles.append(ycolumn) labels.append(col_to_label(ycolumn)) titles += describe_corr(plot_ycorr) labels += describe_corr(plot_ycorr) titles += [ydatum.mapper.fullname, "vs"] if ydatum.function: labels.append(ydatum.function) # add x column/subset.corr, if different if xcolumn and (xcolumn != ycolumn or not xdatum.function) and not xdatum.mapper.column: titles.append(xcolumn) labels.append(col_to_label(xcolumn)) if plot_xcorr != plot_ycorr: titles += describe_corr(plot_xcorr) labels += describe_corr(plot_xcorr) titles += [xdatum.mapper.fullname] if xdatum.function: labels.append(xdatum.function) props['title'] = " ".join(titles) props['label'] = "-".join(labels) # build up intensity label if afunction: titles, labels = [ared], [ared] if acolumn and (acolumn != xcolumn or acolumn != ycolumn) and adatum.mapper.column is None: titles.append(acolumn) labels.append(col_to_label(acolumn)) if plot_acorr and (plot_acorr != plot_xcorr or plot_acorr != plot_ycorr): titles += describe_corr(plot_acorr) labels += describe_corr(plot_acorr) titles += [adatum.mapper.fullname] if adatum.function: labels.append(adatum.function) props['alpha_title'] = " ".join(titles) props['alpha_label'] = "-".join(labels) else: props['alpha_title'] = props['alpha_label'] = '' # build up color-by label if cfunction: titles, labels = [], [] if ccolumn and (ccolumn != xcolumn or ccolumn != ycolumn) and cdatum.mapper.column is None: titles.append(ccolumn) labels.append(col_to_label(ccolumn)) if plot_ccorr and (plot_ccorr != plot_xcorr or plot_ccorr != plot_ycorr): titles += describe_corr(plot_ccorr) labels += describe_corr(plot_ccorr) if cdatum.mapper.fullname: titles.append(cdatum.mapper.fullname) if cdatum.function: labels.append(cdatum.function) if not cdatum.discretized_delta: if not cdatum.discretized_labels or len(cdatum.discretized_labels) 
> cdatum.nlevels: titles.append(f"(modulo {cdatum.nlevels})") props['color_title'] = " ".join(titles) props['color_label'] = "-".join(labels) else: props['color_title'] = props['color_label'] = '' all_plots.append((props, xdatum, ydatum, adatum, ared, cdatum)) log.debug(f"adding plot for {props['title']}") join_corrs = not options.iter_corr and len(subset.corr) > 1 and have_corr_dependence log.info(' : you have asked for {} plots employing {} unique datums'.format(len(all_plots), len(DataAxis.all_axes))) if not len(all_plots): sys.exit(0) log.debug(f"taql is {mytaql}, group_cols is {group_cols}, join subset.corr is {join_corrs}") dataframes, np = \ data_plots.get_plot_data(ms, group_cols, mytaql, ms.chan_freqs, chanslice=chanslice, subset=subset, noflags=options.noflags, noconj=options.noconj, iter_field=options.iter_field, iter_spw=options.iter_spw, iter_scan=options.iter_scan, join_corrs=join_corrs, row_chunk_size=options.row_chunk_size) log.info(f": rendering {len(dataframes)} dataframes with {np:.3g} points into {len(all_plots)} plot types") ## each dataframe is an instance of the axes being iterated over -- on top of that, we need to iterate over plot types # dictionary of substitutions for filename and title keys = {} keys['ms'] = os.path.basename(os.path.splitext(options.ms.rstrip("/"))[0]) keys['timestamp'] = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # dictionary of titles for these identifiers titles = dict(field="", field_num="field", scan="scan", spw="spw", antenna="ant", colortitle="coloured by") keys['field_num'] = subset.field.numbers if options.field != 'all' else '' keys['field'] = subset.field.names if options.field != 'all' else '' keys['scan'] = subset.scan.names if options.scan != 'all' else '' keys['ant'] = subset.ant.names if options.ant != 'all' else '' keys['spw'] = subset.spw.names if options.spw != 'all' else '' keys['suffix'] = suffix = options.suffix.format(**options.__dict__) if options.suffix else '' keys['_suffix'] = 
f".{suffix}" if suffix else '' def generate_string_from_keys(template, keys, listsep=" ", titlesep=" ", prefix=" "): """Converts list of keys into a string suitable for plot titles or filenames. listsep is used to separate list elements: " " for plot titles, "_" for filenames. titlesep is used to separate names from values: " " for plot titles, "_" for filenames prefix is used to prefix {_key} substitutions: ", " for titles, "-" for filenames """ full_keys = keys.copy() # convert lists to strings, add Capitalized keys with titles for key, value in keys.items(): capkey = key.title() if value == '': full_keys[capkey] = '' else: if type(value) is list: # e.g. scan=[1,2] becomes scan="1 2" full_keys[key] = value = listsep.join(map(str, value)) if key in titles: # e.g. Scan="scan 1 2" full_keys[capkey] = f"{titles[key]}{titlesep}{value}" else: full_keys[capkey] = value # add _keys which supply prefixes full_keys.update({f"_{key}": (f"{prefix}{value}" if value else '') for key, value in full_keys.items()}) # finally, format return template.format(**full_keys) jobs = [] executor = None def render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel): """Renders a single plot. 
Make this a function since we might call it in parallel""" log.info(f": rendering {pngname}") normalize = options.norm if normalize == "auto": normalize = "log" if cdatum is not None else "eq_hist" if options.profile: context = dask.diagnostics.ResourceProfiler else: context = nullcontext with context() as profiler: result = data_plots.create_plot(df, xdatum, ydatum, adatum, ared, cdatum, cmap=cmap, bmap=bmap, dmap=dmap, normalize=normalize, xlabel=xlabel, ylabel=ylabel, title=title, pngname=pngname, options=options) if result: log.info(f' : wrote {pngname}') if profiler is not None: profile_file = os.path.splitext(pngname)[0] + ".prof.html" dask.diagnostics.visualize(profiler, file_path=profile_file, show=False, save=True) log.info(f' : wrote profiler info to {profile_file}') for (fld, spw, scan, antenna), df in dataframes.items(): # update keys to be substituted into title and filename if fld is not None: keys['field_num'] = fld keys['field'] = ms.field[fld] if spw is not None: keys['spw'] = spw if scan is not None: keys['scan'] = scan if antenna is not None: keys['ant'] = ms.all_antenna[antenna] # now loop over plot types for props, xdatum, ydatum, adatum, ared, cdatum in all_plots: keys.update(title=props['title'], label=props['label'], alphatitle=props['alpha_title'], alphalabel=props['alpha_label'], colortitle=props['color_title'], colorlabel=props['color_label'], xname=xdatum.fullname, yname=ydatum.fullname, xunit=xdatum.mapper.unit, yunit=ydatum.mapper.unit) pngname = generate_string_from_keys(options.pngname, keys, "_", "_", "-") title = generate_string_from_keys(options.title, keys, " ", " ", ", ") xlabel = generate_string_from_keys(options.xlabel, keys, " ", " ", ", ") ylabel = generate_string_from_keys(options.ylabel, keys, " ", " ", ", ") if options.dir: pngname = os.path.join(options.dir, pngname) # make output directory, if needed dirname = os.path.dirname(pngname) if dirname and not os.path.exists(dirname): os.mkdir(dirname) log.info(f' : created 
output directory {dirname}') if options.num_parallel < 2 or len(all_plots) < 2: render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel) else: from concurrent.futures import ThreadPoolExecutor executor = ThreadPoolExecutor(options.num_parallel) log.info(f' : submitting job for {pngname}') jobs.append(executor.submit(render_single_plot, df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel)) # wait for jobs to finish if executor is not None: log.info(f' : waiting for {len(jobs)} jobs to complete') for job in jobs: job.result() clock_stop = time.time() elapsed = str(round((clock_stop-clock_start), 2)) log.info('Total time : %s seconds' % (elapsed)) log.info('Finished') blank()
def main(argv):
    """Command-line entry point: parse arguments, select data and render all requested plots.

    Steps, in order:
      * parse ``argv`` with the parser from ``cli()`` and push options to ``data_mappers``/``data_plots``;
      * validate --markup / --hline / --vline / --chan / --ant-num settings;
      * open the MS (``MSInfo``) and build a TAQL selection plus a ``subset`` describing the
        selected antennas, baselines, fields, SPWs, scans and correlations;
      * optionally load a JSON min/max cache (``options.lim_file``);
      * register a ``DataAxis`` for every axis of every plot and build title/label properties;
      * read the data via ``data_plots.get_plot_data`` and render every (dataframe, plot type)
        combination, serially or in a bounded thread pool (``options.num_parallel``);
      * optionally save the min/max cache back to disk.

    Args:
        argv: argument list handed to ``parser.parse_args``.
    """
    # TODO: I want to inspect a different time profile method
    clock_start = time.time()

    # --- dealing with input arguments ---
    parser, optimization_opts = cli()

    # various hidden performance-testing options
    data_mappers.add_options(optimization_opts)
    data_plots.add_options(optimization_opts)

    options = parser.parse_args(argv)

    if options.debug:
        shade_ms.log_console_handler.setLevel(logging.DEBUG)

    # pass options to shade_ms
    data_mappers.set_options(options)
    data_plots.set_options(options)

    # set colormaps
    cmap = data_plots.get_colormap(options.cmap)
    bmap = data_plots.get_colormap(options.bmap)
    dmap = data_plots.get_colormap(options.dmap)

    # figure our list of plots to make
    # define plot parameters
    [xaxes, yaxes, columns, caxes, aaxes, areds, xmins, xmaxs, ymins, ymaxs, amins, amaxs,
     cmins, cmaxs, cnums] = parse_plot_spec(parser, options)

    # check markup arguments
    extra_markup = []
    for funcname, funcargs in options.markup or []:
        from matplotlib.axes import Axes
        if funcname not in dir(Axes):
            parser.error(f"unknown function given in --markup {funcname}")
        args = yaml.safe_load(funcargs)
        # kwargs must always be a mapping: it is validated via .keys() below. A list default
        # broke the case of a list of positional args with no trailing dict.
        kwargs = {}
        if type(args) not in {list, dict}:
            parser.error(f"invalid arguments to --markup {funcname} {funcargs}")
        if type(args) is list:
            # a trailing dict in the list supplies the keyword arguments
            if len(args) > 0 and type(args[-1]) is dict:
                kwargs = args[-1]
                args = args[:-1]
        else:
            kwargs = args
            args = []
        extra_markup.append((funcname, args, kwargs))

    # check vline and hline
    for attr in 'hline', 'vline':
        if getattr(options, attr, None):
            for spec in getattr(options, attr).split(","):
                # coordinate, optionally followed by a linestyle and an optional colour
                match = re.match('(.*?)((--|-|-\\.|:)([a-zA-Z#].*)?)?$', spec)
                try:
                    coord = float(match.group(1))
                except ValueError:
                    coord = None
                    parser.error(f"invalid --{attr} setting '{spec}'")
                linestyle = match.group(3) or '-'
                extra_markup.append((f"ax{attr}", [coord], dict(ls=linestyle, color=match.group(4) or "black")))

    # re-check that kwargs are valid
    for funcname, args, kwargs in extra_markup:
        log.info(f"markup: {funcname} *{args} **{kwargs})")
        # YAML may yield non-string keys, which are not valid keyword names either
        if not all(isinstance(kw, str) and re.match(r'^\w+$', kw) for kw in kwargs.keys()):
            log.error("the above is not a valid markup specification, please fix")
            sys.exit(1)

    # check chan slice
    try:
        chanslice, chanslice_spec = parse_slice_spec(options.chan)
    except ValueError:
        parser.error(f"invalid selection --chan {options.chan}")

    # issue warning if only a single antenna is specified
    num_ants_warning = False
    if options.ant_num:
        specs = options.ant_num.split(',')
        if len(specs) == 1 and re.fullmatch(r"\d+", specs[0]):
            ant_spec = "ant{}".format(int(specs[0])-1)
            num_ants_warning = True
        # check ant selection before processing
        for spec in options.ant_num.split(","):
            try:
                antslice, antslice_spec = parse_slice_spec(spec)
            except ValueError:
                parser.error(f"invalid selection --{'ant-num'} {options.ant_num}")
    elif 'all' not in options.ant:
        specs = options.ant.split(',')
        if len(specs) == 1:
            ant_spec = "{}".format(specs[0])
            num_ants_warning = True

    if num_ants_warning:
        msg = f"Suggested usage: '--baseline {ant_spec}-*' to select all baselines to {ant_spec}."
        warnings.warn(msg, category=UserWarning)

    log.info(" ".join(sys.argv))

    separator()

    # --- dealing with input arguments ---
    ms = MSInfo(options.ms, log=log)

    separator()
    log.info(": Data selected for plotting:")

    group_cols = ['FIELD_ID', 'DATA_DESC_ID']

    # --- building SQL query --
    mytaql = []

    class Subset(object):
        pass
    subset = Subset()

    # antenna selection
    if 'all' not in options.ant or options.ant_num:
        if options.ant_num:
            ant_subset = set()
            for spec in options.ant_num.split(","):
                if re.fullmatch(r"\d+", spec):
                    ant_subset.add(int(spec))
                else:
                    antslice, antslice_spec = parse_slice_spec(spec)
                    ant_subset.update(ms.all_antenna.numbers[antslice])
            subset.ant = ms.antenna.get_subset(sorted(ant_subset))
        else:
            subset.ant = ms.antenna.get_subset(options.ant)
        log.info(f"Antenna name(s) : {' '.join(subset.ant.names)}")
        antnum_set = f"[{','.join(map(str, subset.ant.numbers))}]"
        mytaql.append(f"ANTENNA1 IN {antnum_set} && ANTENNA2 IN {antnum_set}")
    else:
        subset.ant = ms.antenna
        log.info('Antenna(s) : all')

    # baseline selection
    if options.baseline == 'all':
        log.info('Baseline(s) : all')
        subset.baseline = ms.baseline
    elif options.baseline == 'noautocorr':
        log.info('Baseline(s) : all except autocorrelations')
        subset.baseline = ms.all_baseline.get_subset([i for i in ms.baseline.numbers if ms.baseline_lengths[i] != 0])
        mytaql.append("ANTENNA1!=ANTENNA2")
    elif options.baseline == 'autocorr':
        log.info('Baseline(s) : autocorrelations')
        subset.baseline = ms.all_baseline.get_subset([i for i in ms.baseline.numbers if ms.baseline_lengths[i] == 0])
        mytaql.append("ANTENNA1==ANTENNA2")
    else:
        bls = set()
        a1a2 = set()
        for blspec in options.baseline.split(","):
            # "A-B", "A-*" or "A-" (the latter two meaning all baselines to A)
            match = re.fullmatch(r"(\w+)-(\w*|[*])", blspec)
            ant1 = match and ms.antenna[match.group(1)]
            ant2 = match and (ms.antenna[match.group(2)] if match.group(2) not in ['', '*'] else '*')
            if ant1 is None or ant2 is None:
                raise ValueError(f"invalid baseline '{blspec}'")
            if ant2 == '*':
                ant2set = ms.all_antenna.numbers
            else:
                ant2set = [ant2]
            # loop
            for ant2 in ant2set:
                a1, a2 = min(ant1, ant2), max(ant1, ant2)
                a1a2.add((a1, a2))
                bls.add(ms.baseline_number(a1, a2))
        subset.baseline = ms.all_baseline.get_subset(sorted(bls))
        log.info(f"Baseline(s) : {' '.join(subset.baseline.names)}")
        mytaql.append("||".join([f'(ANTENNA1=={ant1}&&ANTENNA2=={ant2})||(ANTENNA1=={ant2}&&ANTENNA2=={ant1})'
                                 for ant1, ant2 in a1a2]))

    # field selection
    if options.field != 'all':
        subset.field = ms.field.get_subset(options.field)
        log.info(f"Field(s) : {' '.join(subset.field.names)}")
        # here for now, workaround for https://github.com/ska-sa/dask-ms/issues/100, should be inside if clause
        mytaql.append("||".join([f'FIELD_ID=={fld}' for fld in subset.field.numbers]))
    else:
        subset.field = ms.field
        log.info('Field(s) : all')

    # SPW selection
    if options.spw != 'all':
        subset.spw = ms.spw.get_subset(options.spw)
        log.info(f"SPW(s) : {' '.join(subset.spw.names)}")
        mytaql.append("||".join([f'DATA_DESC_ID=={ddid}' for ddid in subset.spw.numbers]))
    else:
        subset.spw = ms.spw
        log.info('SPW(s) : all')

    # scan selection
    if options.scan != 'all':
        subset.scan = ms.scan.get_subset(options.scan, allow_numeric_indices=False)
        log.info(f"Scan(s) : {' '.join(subset.scan.names)}")
        mytaql.append("||".join([f'SCAN_NUMBER=={n}' for n in subset.scan.numbers]))
    else:
        subset.scan = ms.scan
        log.info('Scan(s) : all')
    if options.iter_scan:
        group_cols.append('SCAN_NUMBER')

    if chanslice == slice(None):
        log.info('Channels : all')
    else:
        log.info(f"Channels : {':'.join(str(x) if x is not None else '' for x in chanslice_spec)}")

    mytaql = ' && '.join([f"({t})" for t in mytaql]) if mytaql else ''
    # --- building SQL query --

    if options.corr == "ALL":
        subset.corr = ms.corr
    else:
        # recognize shortcut when it's just Stokes indices, convert to list
        if re.fullmatch(r"[IQUV]+", options.corr):
            options.corr = ",".join(options.corr)
        subset.corr = ms.all_corr.get_subset(options.corr)
    log.info(f"Corr/Stokes : {' '.join(subset.corr.names)}")

    separator()

    # check minmax cache
    # MSs are directories and are often given with a trailing slash, which would make
    # basename() return an empty string -- strip it (same as keys['ms'] below)
    msbase = os.path.splitext(os.path.basename(options.ms.rstrip("/")))[0]
    cache_file = options.lim_file.format(ms=msbase)
    if options.dir and "/" not in cache_file:
        cache_file = os.path.join(options.dir, cache_file)

    # try to load the minmax cache file
    if not os.path.exists(cache_file):
        minmax_cache = {}
    else:
        log.info(f"loading minmax cache from {cache_file}")
        try:
            with open(cache_file, "rt") as cache_fh:
                minmax_cache = json.load(cache_fh)
            if type(minmax_cache) is not dict:
                raise TypeError("cache content is not a dict")
        except Exception as exc:
            # best-effort: a corrupt cache is reported and discarded, not fatal
            log.error(f"error reading cache file: {exc}. Minmax cache will be reset.")
            minmax_cache = {}

    # figure out list of plots to make
    all_plots = []

    # This will be True if any of the specified axes change with correlation
    have_corr_dependence = False

    # now go create definitions
    for xaxis, yaxis, default_column, caxis, aaxis, ared, xmin, xmax, ymin, ymax, amin, amax, cmin, cmax, cnum in \
            zip(xaxes, yaxes, columns, caxes, aaxes, areds, xmins, xmaxs, ymins, ymaxs, amins, amaxs,
                cmins, cmaxs, cnums):
        # get axis specs
        xspecs = DataAxis.parse_datum_spec(xaxis, default_column, ms=ms)
        yspecs = DataAxis.parse_datum_spec(yaxis, default_column, ms=ms)
        aspecs = DataAxis.parse_datum_spec(aaxis, default_column, ms=ms) if aaxis else [None] * 4
        cspecs = DataAxis.parse_datum_spec(caxis, default_column, ms=ms) if caxis else [None] * 4
        # parse axis specifications
        xfunction, xcolumn, xcorr, xitercorr = xspecs
        yfunction, ycolumn, ycorr, yitercorr = yspecs
        afunction, acolumn, acorr, aitercorr = aspecs
        cfunction, ccolumn, ccorr, citercorr = cspecs
        # does anything here depend on correlation?
        datum_itercorr = (xitercorr or yitercorr or aitercorr or citercorr)
        if datum_itercorr:
            have_corr_dependence = True
        if "FLAG" in (xcolumn, ycolumn, acolumn, ccolumn) or "FLAG_ROW" in (xcolumn, ycolumn, acolumn, ccolumn):
            if not options.noflags:
                log.info(": plotting a flag column implies that flagged data will not be masked")
                options.noflags = True

        # do we iterate over correlations/Stokes to make separate plots now?
        if datum_itercorr and options.iter_corr:
            corr_list = subset.corr.numbers
        else:
            corr_list = [None]

        def describe_corr(corrvalue):
            """Returns list of correlation labels corresponding to this corr setting"""
            if corrvalue is None:
                return subset.corr.names
            elif corrvalue is False:
                return []
            else:
                return [ms.all_corr.names[corrvalue]]

        for corr in corr_list:
            # False if no corr in datum, None if all, else set to iterant or to fixed value
            plot_xcorr = corr if xcorr is None else xcorr
            plot_ycorr = corr if ycorr is None else ycorr
            plot_acorr = corr if acorr is None else acorr
            plot_ccorr = corr if ccorr is None else ccorr
            xdatum = DataAxis.register(xfunction, xcolumn, plot_xcorr, ms=ms, minmax=(xmin, xmax), subset=subset,
                                       minmax_cache=minmax_cache if options.xlim_load else None)
            ydatum = DataAxis.register(yfunction, ycolumn, plot_ycorr, ms=ms, minmax=(ymin, ymax), subset=subset,
                                       minmax_cache=minmax_cache if options.ylim_load else None)
            adatum = afunction and DataAxis.register(afunction, acolumn, plot_acorr, ms=ms,
                                                     minmax=(amin, amax), subset=subset)
            cdatum = cfunction and DataAxis.register(cfunction, ccolumn, plot_ccorr, ms=ms,
                                                     minmax=(cmin, cmax), ncol=cnum, subset=subset,
                                                     minmax_cache=minmax_cache if options.clim_load else None)

            # figure out plot properties -- basically construct a descriptive name and label
            # looks complicated, but we're just trying to figure out what to put in the plot title...
            props = dict()
            titles = []
            labels = []
            # start with column and correlation(s)
            if ycolumn and not ydatum.mapper.column:  # only put column if not fixed by mapper
                titles.append(ycolumn)
                labels.append(col_to_label(ycolumn))
            titles += describe_corr(plot_ycorr)
            labels += describe_corr(plot_ycorr)
            if ydatum.mapper.fullname:
                titles += [ydatum.mapper.fullname]
            titles += ["vs"]
            if ydatum.function:
                labels.append(ydatum.function)

            # add x column/subset.corr, if different
            if xcolumn and (xcolumn != ycolumn or not xdatum.function) and not xdatum.mapper.column:
                titles.append(xcolumn)
                labels.append(col_to_label(xcolumn))
            # identity (not equality) comparison keeps corr 0 distinct from False
            if plot_xcorr is not plot_ycorr:
                titles += describe_corr(plot_xcorr)
                labels += describe_corr(plot_xcorr)
            if xdatum.mapper.fullname:
                titles += [xdatum.mapper.fullname]
            if xdatum.function:
                labels.append(xdatum.function)
            props['title'] = " ".join(titles)
            props['label'] = "-".join(labels)

            # build up intensity label
            if afunction:
                titles, labels = [ared], [ared]
                if acolumn and (acolumn != xcolumn or acolumn != ycolumn) and adatum.mapper.column is None:
                    titles.append(acolumn)
                    labels.append(col_to_label(acolumn))
                if plot_acorr is not None and plot_acorr is not False and \
                        (plot_acorr is not plot_xcorr or plot_acorr is not plot_ycorr):
                    titles += describe_corr(plot_acorr)
                    labels += describe_corr(plot_acorr)
                titles += [adatum.mapper.fullname]
                if adatum.function:
                    labels.append(adatum.function)
                props['alpha_title'] = " ".join(titles)
                props['alpha_label'] = "-".join(labels)
            else:
                props['alpha_title'] = props['alpha_label'] = ''

            # build up color-by label
            if cfunction:
                titles, labels = [], []
                if ccolumn and (ccolumn != xcolumn or ccolumn != ycolumn) and cdatum.mapper.column is None:
                    titles.append(ccolumn)
                    labels.append(col_to_label(ccolumn))
                if plot_ccorr is not None and plot_ccorr is not False and \
                        (plot_ccorr is not plot_xcorr or plot_ccorr is not plot_ycorr):
                    titles += describe_corr(plot_ccorr)
                    labels += describe_corr(plot_ccorr)
                if cdatum.mapper.fullname:
                    titles.append(cdatum.mapper.fullname)
                if cdatum.function:
                    labels.append(cdatum.function)
                props['color_title'] = " ".join(titles)
                props['color_label'] = "-".join(labels)
                props['color_modulo'] = f"(into {cdatum.nlevels} colours)"
            else:
                props['color_title'] = props['color_label'] = ''
                props['color_modulo'] = ''

            all_plots.append((props, xdatum, ydatum, adatum, ared, cdatum))
            log.debug(f"adding plot for {props['title']}")

    # reset minmax cache if requested
    if options.lim_save_reset:
        minmax_cache = {}

    join_corrs = not options.iter_corr and len(subset.corr) > 1 and have_corr_dependence

    log.info(' : you have asked for {} plots employing {} unique datums'.format(len(all_plots),
                                                                            len(DataAxis.all_axes)))
    if not len(all_plots):
        sys.exit(0)

    log.debug(f"taql is {mytaql}, group_cols is {group_cols}, join subset.corr is {join_corrs}")

    dataframes, index_subsets, num_points = \
        data_plots.get_plot_data(ms, group_cols, mytaql, ms.chan_freqs,
                                 chanslice=chanslice, subset=subset,
                                 noflags=options.noflags, noconj=options.noconj,
                                 iter_field=options.iter_field, iter_spw=options.iter_spw,
                                 iter_scan=options.iter_scan,
                                 iter_ant=options.iter_ant, iter_baseline=options.iter_baseline,
                                 join_corrs=join_corrs,
                                 row_chunk_size=options.row_chunk_size)

    if len(dataframes) < 1:
        log.warning("No data for selection subset")

    log.info(f": rendering {len(dataframes)} dataframes with {num_points:.3g} points into {len(all_plots)} plot types")

    ## each dataframe is an instance of the axes being iterated over -- on top of that, we need to iterate over plot types

    # dictionary of substitutions for filename and title
    keys = {}
    keys['ms'] = os.path.basename(os.path.splitext(options.ms.rstrip("/"))[0])
    keys['timestamp'] = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    # dictionary of titles for these identifiers
    titles = dict(field="", field_num="field", scan="scan", spw="spw", antenna="ant", colortitle="coloured by")

    keys['field_num'] = subset.field.numbers if options.field != 'all' else ''
    keys['field'] = subset.field.names if options.field != 'all' else ''
    keys['scan'] = subset.scan.names if options.scan != 'all' else ''
    keys['ant'] = subset.ant.names if options.ant != 'all' else ''  ## TODO: also handle ant-num settings
    keys['spw'] = subset.spw.names if options.spw != 'all' else ''
    # '' marks "not set" like the other keys; None would be rendered as the text "None"
    keys['baseline'] = ''

    keys['suffix'] = suffix = options.suffix.format(**options.__dict__) if options.suffix else ''
    keys['_suffix'] = f".{suffix}" if suffix else ''

    def generate_string_from_keys(template, keys, listsep=" ", titlesep=" ", prefix=" "):
        """Converts list of keys into a string suitable for plot titles or filenames.
        listsep is used to separate list elements: " " for plot titles, "_" for filenames.
        titlesep is used to separate names from values: " " for plot titles, "_" for filenames
        prefix is used to prefix {_key} substitutions: ", " for titles, "-" for filenames
        """
        full_keys = keys.copy()
        # convert lists to strings, add Capitalized keys with titles
        for key, value in keys.items():
            capkey = key.title()
            if value == '':
                full_keys[capkey] = ''
            else:
                if type(value) is list:  # e.g. scan=[1,2] becomes scan="1 2"
                    full_keys[key] = value = listsep.join(map(str, value))
                if key in titles:  # e.g. Scan="scan 1 2"
                    full_keys[capkey] = f"{titles[key]}{titlesep}{value}"
                else:
                    full_keys[capkey] = value
        # add _keys which supply prefixes
        full_keys.update({f"_{key}": (f"{prefix}{value}" if value else '') for key, value in full_keys.items()})
        # finally, format
        return template.format(**full_keys)

    jobs = []
    executor = None

    def render_single_plot(df, subset, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel):
        """Renders a single plot. Make this a function since we might call it in parallel"""
        log.info(f": rendering {pngname}")
        normalize = options.norm
        if normalize == "auto":
            normalize = "log" if cdatum is not None else ("eq_hist" if adatum is None else 'linear')
        if options.profile:
            context = dask.diagnostics.ResourceProfiler
        else:
            context = nullcontext
        with context() as profiler:
            result = data_plots.create_plot(df, subset, xdatum, ydatum, adatum, ared, cdatum,
                                            cmap=cmap, bmap=bmap, dmap=dmap,
                                            normalize=normalize,
                                            min_alpha=options.min_alpha,
                                            saturate_alpha=options.saturate_alpha,
                                            saturate_percentile=options.saturate_perc,
                                            xlabel=xlabel, ylabel=ylabel, title=title, pngname=pngname,
                                            extra_markup=extra_markup,
                                            minmax_cache=minmax_cache,
                                            options=options)
        if result:
            log.info(f' : wrote {pngname}')
            if profiler is not None:
                profile_file = os.path.splitext(pngname)[0] + ".prof.html"
                dask.diagnostics.visualize(profiler, file_path=profile_file, show=False, save=True)
                log.info(f' : wrote profiler info to {profile_file}')

    for (fld, spw, scan, antenna_or_baseline), df in dataframes.items():
        df_subset = index_subsets[fld, spw, scan, antenna_or_baseline]
        # update keys to be substituted into title and filename
        if fld is not None:
            keys['field_num'] = fld
            keys['field'] = ms.field[fld]
        if spw is not None:
            keys['spw'] = spw
        if scan is not None:
            keys['scan'] = scan
        if antenna_or_baseline is not None:
            if options.iter_ant:
                keys['ant'] = ms.all_antenna[antenna_or_baseline]
            elif options.iter_baseline:
                keys['baseline'] = ms.all_baseline[antenna_or_baseline]

        # now loop over plot types
        for props, xdatum, ydatum, adatum, ared, cdatum in all_plots:
            keys.update(title=props['title'], label=props['label'],
                        alphatitle=props['alpha_title'], alphalabel=props['alpha_label'],
                        colortitle=props['color_title'], colorlabel=props['color_label'],
                        xname=xdatum.fullname, yname=ydatum.fullname,
                        xunit=xdatum.mapper.unit, yunit=ydatum.mapper.unit)
            if cdatum is not None and cdatum.is_discrete and not cdatum.discretized_labels:
                keys['colortitle'] += ' ' + props['color_modulo']

            pngname = generate_string_from_keys(options.pngname, keys, "_", "_", "-")
            title = generate_string_from_keys(options.title, keys, " ", " ", ", ")
            xlabel = generate_string_from_keys(options.xlabel, keys, " ", " ", ", ")
            ylabel = generate_string_from_keys(options.ylabel, keys, " ", " ", ", ")

            if options.dir:
                pngname = os.path.join(options.dir, pngname)

            # make output directory, if needed (makedirs also handles nested template paths)
            dirname = os.path.dirname(pngname)
            if dirname and not os.path.exists(dirname):
                os.makedirs(dirname, exist_ok=True)
                log.info(f' : created output directory {dirname}')

            if options.num_parallel < 2 or len(all_plots) < 2:
                render_single_plot(df, df_subset, xdatum, ydatum, adatum, ared, cdatum,
                                   pngname, title, xlabel, ylabel)
            else:
                # create the pool once and reuse it for every job, so that at most
                # options.num_parallel renders run at a time
                if executor is None:
                    from concurrent.futures import ThreadPoolExecutor
                    executor = ThreadPoolExecutor(options.num_parallel)
                log.info(f' : submitting job for {pngname}')
                jobs.append(executor.submit(render_single_plot, df, df_subset, xdatum, ydatum, adatum, ared, cdatum,
                                            pngname, title, xlabel, ylabel))

    # wait for jobs to finish
    if executor is not None:
        log.info(f' : waiting for {len(jobs)} jobs to complete')
        try:
            for job in jobs:
                job.result()  # re-raises any exception from the worker thread
        finally:
            executor.shutdown()

    clock_stop = time.time()
    elapsed = str(round((clock_stop-clock_start), 2))

    log.info('Total time : %s seconds' % (elapsed))

    if minmax_cache and options.lim_save:
        # ensure floats, because in64s and such cause errors
        minmax_cache = {axis: list(map(float, minmax)) for axis, minmax in minmax_cache.items()}

        with open(cache_file, "wt") as cache_fh:
            json.dump(minmax_cache, cache_fh, sort_keys=True, indent=4, separators=(',', ': '))

        log.info(f"Saved minmax cache to {cache_file} (disable with --no-lim-save)")

    log.info('Finished')
    separator()
def __init__(self, column, function, corr, ms, minmax=None, ncol=None, label=None, subset=None):
    """Set up one data axis. See register() class method above. Not called directly.

    Args:
        column: MS column name, or a two-column arithmetic expression such as "A-B"
            (matched by ``(\\w+)([*/+-])(\\w+)``); may be None/empty when the mapper fixes the column.
        function: key into ``data_mappers`` selecting the mapping applied to the column.
        corr: correlation/Stokes selection; the string "all" is stored as None in ``self.corr``.
        ms: measurement-set info object (presumably an MSInfo instance -- confirm); its
            ``field.names`` and ``all_antenna.names`` are used for discrete labels.
        minmax: (min, max) clipping limits; either element may be None.
        ncol: number of discrete levels (colours), or None for a continuous axis.
        label: unique label for this axis (assigned by register()).
        subset: selection object whose ``corr``, ``field`` and ``ant`` attributes restrict the
            discrete labels.

    Raises:
        ValueError: if the mapper implies a fixed column that differs from ``column``.
    """
    self.name = ":".join([str(x) for x in (function, column, corr, minmax, ncol) if x is not None])
    self.ms = ms
    self.function = function  # function to apply to column (see list of DataMappers below)
    self.corr = corr if corr != "all" else None
    self.nlevels = ncol
    self.minmax = vmin, vmax = tuple(minmax) if minmax is not None else (None, None)
    self.label = label
    self._corr_reduce = None
    self.discretized_labels = None  # filled below for corr/Stokes, field, antenna and flag axes

    # set up discretized continuous axis
    if self.nlevels and vmin is not None and vmax is not None:
        self.discretized_delta = delta = (vmax - vmin) / self.nlevels
        # centre of bin i is vmin + (i + 1/2) * delta. Computing it directly gives exactly
        # nlevels centres for integer nlevels, and unlike numpy.arange(start, vmax, delta)
        # does not fail when vmin == vmax (zero step).
        self.discretized_bin_centers = vmin + delta * (numpy.arange(self.nlevels) + 0.5)
    else:
        self.discretized_delta = self.discretized_bin_centers = None

    self.mapper = data_mappers[function]

    # columns with labels?
    if function in ('CORR', 'STOKES'):
        corrset = set(subset.corr.names)
        corrset1 = corrset - set("IQUV")
        if corrset1 == corrset:
            name = "Correlation"
        elif not corrset1:
            name = "Stokes"
        else:
            name = "Correlation or Stokes"
        # special case of "corr" if corr is fixed: return constant value fixed here
        # When corr is fixed, we're creating a mapper for a known correlation. When corr is not fixed,
        # we're creating one for a mapper that will iterate over correlations
        if corr is not None:
            self.mapper = DataMapper(name, "", column=False, axis=-1, mapper=lambda x: corr)
        self.discretized_labels = subset.corr.names
    elif column == "FIELD_ID":
        self.discretized_labels = [name for name in ms.field.names if name in subset.field]
    elif column in ("ANTENNA1", "ANTENNA2"):
        self.discretized_labels = [name for name in ms.all_antenna.names if name in subset.ant]
    elif column in ("FLAG", "FLAG_ROW"):
        self.discretized_labels = ["F", "T"]

    # axis name
    self.fullname = self.mapper.fullname or column or ''

    # if labels were set up, adjust nlevels (but only down, never up)
    if self.discretized_labels is not None and self.nlevels is not None:
        self.nlevels = min(len(self.discretized_labels), self.nlevels)

    if self.function == "_":
        self.function = ""
    self.conjugate = self.mapper.conjugate
    self.timefreq_axis = self.mapper.axis

    # setup columns
    self._ufunc = None
    self.columns = ()

    # does the mapper have no column (i.e. frequency)?
    if self.mapper.column is False:
        log.info(f'axis: {function}, range {self.minmax}, discretization {self.nlevels}')
    # does the mapper have a fixed column? This better be consistent
    elif self.mapper.column is not None:
        # if a mapper (such as "uv") implies a fixed column name, make sure it's consistent with what the user said
        if column and self.mapper.column != column:
            raise ValueError(f"'{function}' not applicable with column {column}")
        # use the mapper's own column: identical to `column` when one was given, and avoids
        # reading a None/'' column when the user did not name one
        self.columns = (self.mapper.column,)
        log.info(f'axis: {function}({column}), range {self.minmax}, discretization {self.nlevels}')
    # else arbitrary column
    else:
        log.info(f'axis: {function}({column}), corr {self.corr}, range {self.minmax}, discretization {self.nlevels}')
        # check for column arithmetic
        match = re.fullmatch(r"(\w+)([*/+-])(\w+)", column)
        if match:
            self.columns = (match.group(1), match.group(3))
            # look up dask ufunc corresponding to arithmetic op
            self._ufunc = {'+': da.add, '*': da.multiply, '-': da.subtract, '/': da.divide}[match.group(2)]
        else:
            self.columns = (column,)
def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, chanslice, subset,
                  noflags, noconj,
                  iter_field, iter_spw, iter_scan, iter_ant, iter_baseline,
                  join_corrs=False,
                  row_chunk_size=100000):
    """Read MS data and build the dataframes to be plotted.

    Every axis registered in ``DataAxis.all_axes`` becomes one column (named by the
    axis label) of every output dataframe.

    Args:
        msinfo: MS info object (``msname``, ``indexing_columns``, ``baseline_number()``).
        group_cols: columns passed to ``daskms.xds_from_ms()`` as ``group_cols``.
        mytaql: optional TaQL selection string (falsy for no selection).
        chan_freqs: channel frequencies, indexed by DATA_DESC_ID.
        chanslice: channel slice applied to the flags and passed to each axis' ``get_value()``.
        subset: selection object (``field``, ``spw``, ``ant``, ``corr`` subsets).
        noflags: if true, FLAG and FLAG_ROW are not read.
        noconj: if true, conjugated copies of the data are never added.
        iter_field, iter_spw, iter_scan, iter_ant, iter_baseline: if true, the corresponding
            index gets its own position in the dataframe key (i.e. a separate frame per value).
        join_corrs: if true, correlation-dependent axes are recomputed for every correlation
            and all correlations are stacked into the same frame.
        row_chunk_size: dask row chunk size used when reading the MS.

    Returns:
        Tuple ``(output_dataframes, output_subsets, total_num_points)``:
        ``output_dataframes`` maps keys ``(field, spw, scan, antenna_or_baseline)`` (None at
        positions that are not iterated over) to dataframes; ``output_subsets`` maps the same
        keys to ``{indexing_column: set of values seen}``; ``total_num_points`` is the total
        number of points generated. Keys for which no points were generated get no dataframe.
    """
    ms_cols = {'ANTENNA1', 'ANTENNA2'}
    ms_cols.update(msinfo.indexing_columns.keys())
    if not noflags:
        ms_cols.update({'FLAG', 'FLAG_ROW'})
    # get visibility columns
    for axis in DataAxis.all_axes.values():
        ms_cols.update(axis.columns)

    # If any axis needs to be conjugated, every frame is doubled up with the conjugated
    # (negated) values of those axes. This does not change inside the loops, so decide once.
    conjugate_by_label = {axis.label: axis.conjugate for axis in DataAxis.all_axes.values()}
    double_up = not noconj and any(axis.conjugate for axis in DataAxis.all_axes.values())

    # Add baselines to group columns. This must be done exactly once: doing it inside the
    # antenna loop would append another ANTENNA1/ANTENNA2 pair on every iteration.
    if iter_baseline:
        group_cols = list(group_cols) + ["ANTENNA1", "ANTENNA2"]

    total_num_points = 0  # total number of points to plot

    # output dataframes, indexed by (field, spw, scan, antenna_or_baseline)
    # If any of these axes is not being iterated over, then the index at that position is None
    output_dataframes = OrderedDict()
    # output subsets of indexing columns, indexed by same tuple
    output_subsets = OrderedDict()

    if iter_ant:
        antenna_subsets = zip(subset.ant.numbers, subset.ant.names)
    else:
        antenna_subsets = [(None, None)]
    taql = mytaql

    for antenna, antname in antenna_subsets:
        if antenna is not None:
            antenna_taql = f"(ANTENNA1=={antenna} || ANTENNA2=={antenna})"
            taql = f"({mytaql})&&{antenna_taql}" if mytaql else antenna_taql

        # get MS data
        msdata = daskms.xds_from_ms(msinfo.msname, columns=list(ms_cols), group_cols=group_cols,
                                    taql_where=taql, chunks=dict(row=row_chunk_size))
        nrow = sum(len(group.row) for group in msdata)
        if not nrow:
            continue

        if antenna is not None:
            log.info(f': Indexing sub-MS (antenna {antname}) and building dataframes ({nrow} rows, chunk size is {row_chunk_size})')
        else:
            log.info(f': Indexing MS and building dataframes ({nrow} rows, chunk size is {row_chunk_size})')

        # iterate over groups
        for group in msdata:
            if not len(group.row):
                continue
            ddid = group.DATA_DESC_ID  # always present
            fld = group.FIELD_ID  # always present
            if fld not in subset.field or ddid not in subset.spw:
                log.debug(f"field {fld} ddid {ddid} not in selection, skipping")
                continue
            scan = getattr(group, 'SCAN_NUMBER', None)  # will be present if iterating over scans
            if iter_baseline:
                ant1 = getattr(group, 'ANTENNA1', None)  # will be present if iterating over baselines
                ant2 = getattr(group, 'ANTENNA2', None)  # will be present if iterating over baselines
                baseline = msinfo.baseline_number(ant1, ant2)
            else:
                baseline = None

            # Make frame key -- data subset corresponds to this frame
            dataframe_key = (fld if iter_field else None,
                             ddid if iter_spw else None,
                             scan if iter_scan else None,
                             antenna if antenna is not None else baseline)

            # update subsets of MS indexing columns that we've seen for this dataframe
            output_subset1 = output_subsets.setdefault(
                dataframe_key, {column: set() for column in msinfo.indexing_columns.keys()})
            for column in msinfo.indexing_columns.keys():
                value = getattr(group, column)
                if np.isscalar(value):
                    output_subset1[column].add(value)
                else:
                    output_subset1[column].update(value.compute().data)

            # flags are read unless disabled
            flag = group.FLAG if not noflags else None
            flag_row = group.FLAG_ROW if not noflags else None

            a1 = da.minimum(group.ANTENNA1.data, group.ANTENNA2.data)
            a2 = da.maximum(group.ANTENNA1.data, group.ANTENNA2.data)
            baselines = msinfo.baseline_number(a1, a2)

            freqs = chan_freqs[ddid]
            chans = xarray.DataArray(range(len(freqs)), dims=("chan",))
            wavel = freq_to_wavel(freqs)
            extras = dict(chans=chans, freqs=freqs, wavel=wavel, rows=group.row, baselines=baselines)

            nchan = len(group.chan)
            if flag is not None:
                flag = flag[dict(chan=chanslice)]
                nchan = flag.shape[1]
            shape = (len(group.row), nchan)

            arrays = OrderedDict()
            shapes = OrderedDict()
            ddf = None

            for corr in subset.corr.numbers:
                # make dictionary of extra values for DataMappers
                extras['corr'] = corr
                # Reset per correlation: a frame is only emitted for a correlation if it actually
                # produced new values. Without this reset, correlations whose axes were all skipped
                # would re-emit (duplicate) the previous correlation's arrays.
                num_points = 0
                computed_new = False
                # loop over datums to be computed
                for axis in DataAxis.all_axes.values():
                    value = arrays.get(axis.label)
                    # a datum was already computed?
                    if value is not None:
                        # if not joining correlations, then that's the only one we'll need, so continue
                        if not join_corrs:
                            continue
                        # joining correlations, and datum has a correlation dependence: compute another one
                        if axis.corr is None:
                            value = None
                    if value is None:
                        value = axis.get_value(group, corr, extras, flag=flag, flag_row=flag_row,
                                               chanslice=chanslice)
                        computed_new = True
                    num_points = max(num_points, value.size)
                    if value.ndim == 0:
                        shapes[axis.label] = ()
                    elif value.ndim == 1:
                        timefreq_axis = axis.mapper.axis or 0
                        assert value.shape[0] == shape[timefreq_axis], \
                            f"{axis.mapper.fullname}: size {value.shape[0]}, expected {shape[timefreq_axis]}"
                        shapes[axis.label] = ("row",) if timefreq_axis == 0 else ("chan",)
                    # else 2D value better match expected shape
                    else:
                        assert value.shape == shape, f"{axis.mapper.fullname}: shape {value.shape}, expected {shape}"
                        shapes[axis.label] = ("row", "chan")
                    arrays[axis.label] = value

                # any new data generated for this correlation? If not, make no dataframe
                if not (computed_new and num_points):
                    continue
                total_num_points += num_points
                # dataframe_factory takes alternating (array, dims) arguments
                df_args = [item for label, array in arrays.items() for item in (array, shapes[label])]
                df1 = dataframe_factory(("row", "chan"), *df_args, columns=arrays.keys())
                # if any axis needs to be conjugated, double up all of them
                if double_up:
                    conj_args = [item for label, array in arrays.items()
                                 for item in (-array if conjugate_by_label[label] else array, shapes[label])]
                    df2 = dataframe_factory(("row", "chan"), *conj_args, columns=arrays.keys())
                    df1 = dask_df.concat([df1, df2], axis=0)
                ddf = dask_df.concat([ddf, df1], axis=0) if ddf is not None else df1

            # nothing was generated for this group: don't record an empty (None) frame
            if ddf is None:
                continue

            # do we already have a frame for this key
            ddf0 = output_dataframes.get(dataframe_key)
            if ddf0 is None:
                log.debug(f"first frame for {dataframe_key}")
                output_dataframes[dataframe_key] = ddf
            else:
                log.debug(f"appending to frame for {dataframe_key}")
                output_dataframes[dataframe_key] = dask_df.concat([ddf0, ddf], axis=0)

    # convert discrete axes into categoricals
    if data_mappers.USE_COUNT_CAT:
        categorical_axes = [axis.label for axis in DataAxis.all_axes.values() if axis.nlevels]
        if categorical_axes:
            log.info(": counting colours")
            for key, ddf in list(output_dataframes.items()):
                output_dataframes[key] = ddf.categorize(categorical_axes)

    log.info(": complete")
    return output_dataframes, output_subsets, total_num_points
def create_plot(ddf, index_subsets, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, normalize,
                xlabel, ylabel, title, pngname,
                extra_markup=(),
                min_alpha=40, saturate_percentile=None, saturate_alpha=None,
                minmax_cache=None,
                options=None):
    """Rasterise a dataframe with datashader and render it to a PNG file with matplotlib.

    Args:
        ddf: dataframe with one column per axis label.
        index_subsets: dict mapping MS indexing columns to the set of values present in ``ddf``;
            used to narrow the colour labels of a discrete colour axis.
        xdatum, ydatum: DataAxis objects for the X and Y axes.
        adatum: optional DataAxis that is reduced (via ``ared``) into pixel intensity.
        ared: name of a ``datashader.reductions`` function (count is used if falsy).
        cdatum: optional DataAxis used to colour points by category.
        cmap: colormap used for intensity (non-categorical) shading.
        bmap: colormap used when binning a continuous colour axis.
        dmap: sequence of colours used for a discrete colour axis.
        normalize: datashader ``how=`` normalization used for shading.
        xlabel, ylabel, title: axis labels and plot title.
        pngname: output filename.
        extra_markup: iterable of ``(method_name, args, kwargs)`` applied to the matplotlib Axes
            before the image is drawn.
        min_alpha: minimum alpha of non-empty pixels, passed to datashader shading.
        saturate_percentile: if ``saturate_alpha`` is not given, the percentile of non-empty
            pixel alpha values that is stretched to full opacity.
        saturate_alpha: explicit alpha level (0-255) that is stretched to full opacity.
        minmax_cache: optional dict updated with any axis min/max values computed here.
        options: options namespace (canvas size, background colour, fonts, spreading, ...).

    Returns:
        ``pngname`` on success, or None if the plot contained no valid data.

    Raises:
        ValueError: if ``ared`` is not a known datashader reduction.
    """
    figx = options.xcanvas / 60
    figy = options.ycanvas / 60
    bgcol = "#" + options.bgcol.lstrip("#")

    xaxis = xdatum.label
    yaxis = ydatum.label
    aaxis = adatum and adatum.label
    caxis = cdatum and cdatum.label

    color_key = color_labels = color_minmax = agg_alpha = None

    # do we need to compute any axis min/max?
    bounds = OrderedDict({xaxis: xdatum.minmax, yaxis: ydatum.minmax})
    unknown = []
    for datum in xdatum, ydatum, cdatum:
        if datum is not None:
            bounds[datum.label] = datum.minmax
            if datum.minmax[0] is None or datum.minmax[1] is None:
                if datum.is_discrete and datum.subset_indices is not None:
                    bounds[datum.label] = 0, len(datum.subset_indices) - 1
                else:
                    unknown.append(datum.label)

    if unknown:
        log.info(f": scanning axis min/max for {' '.join(unknown)}")
        compute_bounds(unknown, bounds, ddf)
        # populate cache
        if minmax_cache is not None:
            minmax_cache.update([(label, bounds[label]) for label in unknown])

    # adjust bounds for discrete axes: centre each index on its own pixel
    canvas_sizes = []
    for datum, size in (xdatum, options.xcanvas), (ydatum, options.ycanvas):
        if datum.is_discrete:
            bounds[datum.label] = bounds[datum.label][0] - 0.5, bounds[datum.label][1] + 0.5
            size = int(bounds[datum.label][1]) - int(bounds[datum.label][0]) + 1
        canvas_sizes.append(size)

    # create rendering canvas.
    canvas = datashader.Canvas(canvas_sizes[0], canvas_sizes[1], x_range=bounds[xaxis], y_range=bounds[yaxis])

    if aaxis is not None:
        agg_alpha = getattr(datashader.reductions, ared, None) if ared else datashader.reductions.count
        if agg_alpha is None:
            raise ValueError(f"unknown alpha reduction function {ared}")
        agg_alpha = agg_alpha(aaxis)

    if cdatum is not None:
        # aggregation applied to by()
        agg_by = agg_alpha if agg_alpha else datashader.count()

        # figure out mapping from raster planes to colours
        # after this if-else block, category will be an aggregator instance yielding N categories,
        # color_key will be a list of N colors, and color_label will be a list of N textual labels
        if data_mappers.USE_COUNT_CAT:
            cats = getattr(ddf.dtypes, caxis).categories
            log.debug(f'colourizing using {caxis} categorical, {len(cats)} bins')
            category = caxis
            color_key = dmap[:len(cats)]
            color_labels = list(map(str, cats))
        elif cdatum.is_discrete:
            # make dictionary from index to label, omitting values that are not in the MS subset to begin with
            if cdatum.discretized_labels:
                active_subset = OrderedDict(enumerate(cdatum.discretized_labels))
            # else make up integer labels on the spot (bounds may be floats, hence int())
            else:
                active_subset = OrderedDict(enumerate(map(str, range(int(bounds[caxis][1]) + 1))))
            # Check if the subset needs to be refined, because it is known to be smaller for this dataframe
            if len(cdatum.columns) == 1 and cdatum.columns[0] in index_subsets:
                df_index_subset = index_subsets[cdatum.columns[0]]
                if cdatum.subset_remapper is not None:
                    remapper = cdatum.subset_remapper.compute()
                    df_index_subset = set(remapper[x] for x in df_index_subset)
                # Sorted for a deterministic order; indices without a label (e.g. the remapper's
                # catch-all bin for unselected values) are dropped rather than raising KeyError.
                active_subset = OrderedDict((idx, active_subset[idx]) for idx in sorted(df_index_subset)
                                            if idx in active_subset)
                log.debug(f"subset of indices for this axis is a priori {list(active_subset.keys())}")
            # max known index
            max_index = max(active_subset.keys())
            num_colors = min(cdatum.nlevels, len(dmap))
            color_key = dmap[:num_colors]
            # if we have fewer indices than colour levels, and the max index is sensible, we'll aggregate to one
            # raster slice per index value directly
            if len(active_subset) <= num_colors and max_index < max(num_colors, 256):
                num_colors = max_index + 1
                log.debug(f"aggregating directly into {max_index+1} categories")
                category = category_modulo(caxis, max_index + 1)
                color_label_list = {idx: [value] for idx, value in active_subset.items()}
            else:
                log.debug(f"aggregating modulo {num_colors} categories")
                category = category_modulo(caxis, num_colors)
                # each slice maps to, potentially, multiple labels from the subset
                color_label_list = {i: [active_subset[idx] for idx in range(i, max_index + 1, num_colors)
                                        if idx in active_subset]
                                    for i in range(num_colors)}
            # convert lists of color labels into strings (at most two labels, then "...")
            color_labels = [''] * num_colors
            for i, labels in color_label_list.items():
                if len(labels) < 3:
                    color_labels[i] = ",".join(labels)
                else:
                    color_labels[i] = ",".join(labels[:2] + ["..."])
        # else we discretize a span of values
        else:
            num_colors = min(cdatum.nlevels, len(bmap))
            log.debug(f'colourizing using {caxis} with {num_colors} bins')
            cmin = bounds[caxis][0]
            cdelta = (bounds[caxis][1] - cmin) / num_colors
            category = category_binning(caxis, cmin, cdelta, num_colors)
            # color labels are bin centres
            bin_centers = [cmin + cdelta * (i + 0.5) for i in range(num_colors)]
            # map to colors pulled from entire extent of color map
            color_key = [bmap[(i * len(bmap)) // num_colors] for i in range(num_colors)]
            color_labels = [str(center) for center in bin_centers]

        # len(color_labels) is the number of colour categories in every branch above
        log.info(f": aggregating using {len(color_labels)} bins at {' '.join(color_labels)}")
        raster = canvas.points(ddf, xaxis, yaxis, agg=datashader.by(category, agg_by))
        is_integer_raster = np.issubdtype(raster.dtype, np.integer)

        # the binning aggregator accumulates flagged points in an extra raster plane
        if isinstance(category, category_binning):
            if is_integer_raster:
                log.info(f": {raster[..., -1].data.sum():.3g} points were flagged ")
            raster = raster[..., :-1]

        if is_integer_raster:
            non_empty = np.array(raster.any(axis=(0, 1)))
        else:
            non_empty = ~(np.isnan(raster.data).all(axis=(0, 1)))
        if not non_empty.any():
            log.info(": no valid data in plot. Check your flags and/or plot limits.")
            return None

        if cdatum.is_discrete and not data_mappers.USE_COUNT_CAT:
            # discard empty planes
            non_empty = np.where(non_empty)[0]
            raster = raster[..., non_empty]
            # compress colours to bottom of colormap, unless asked to preserve assignments.
            # When preserving, wrap around: in the direct-aggregation case plane numbers can exceed
            # the number of colours available.
            if options.dmap_preserve:
                color_key = [color_key[plane % len(color_key)] for plane in non_empty]
            else:
                color_key = color_key[:len(non_empty)]
            color_labels = [color_labels[plane] for plane in non_empty]

        img = datashader.transfer_functions.shade(raster, color_key=color_key, how=normalize, min_alpha=min_alpha)

        # set color_minmax for colorbar
        color_minmax = bounds[caxis]

    else:
        log.debug(f'rasterizing using {ared}')
        raster = canvas.points(ddf, xaxis, yaxis, agg=agg_alpha)
        if not raster.data.any():
            log.info(": no valid data in plot. Check your flags and/or plot limits.")
            return None
        # get min/max for colorbar
        if aaxis:
            amin, amax = adatum.minmax
            color_minmax = ((amin if amin is not None else np.nanmin(raster)),
                            (amax if amax is not None else np.nanmax(raster)))
            color_key = cmap
        log.debug('shading')
        img = datashader.transfer_functions.shade(raster, cmap=cmap, how=normalize, span=color_minmax,
                                                  min_alpha=min_alpha)

    # resaturate if needed
    if saturate_alpha is not None or saturate_percentile is not None:
        # get alpha channel (absolute 0..255 values, as float to avoid unsigned wrap-around)
        imgval = img.values
        alpha = ((imgval >> 24) & 255).astype(float)
        nulls = alpha < min_alpha
        if nulls.all():
            log.debug("alpha<min_alpha for entire plot -- all data below lower clip perhaps?")
        else:
            # an explicit saturate_alpha takes precedence; otherwise derive it from the percentile.
            # Both are in absolute alpha units, so [min_alpha, saturate_alpha] is stretched below.
            if saturate_alpha is None:
                saturate_alpha = np.percentile(alpha[~nulls], saturate_percentile)
                log.debug(f"using saturation alpha {saturate_alpha} from {saturate_percentile}th percentile")
            else:
                log.debug(f"using explicit saturation alpha {saturate_alpha}")
            if saturate_alpha > min_alpha:
                # rescale alpha from [min_alpha, saturation_alpha] to [min_alpha, 255]
                saturation_factor = (255. - min_alpha) / (saturate_alpha - min_alpha)
                alpha = min_alpha + (alpha - min_alpha) * saturation_factor
            else:
                # degenerate saturation level: every non-empty pixel is at or above it
                alpha[:] = 255
            alpha[nulls] = 0
            alpha[alpha > 255] = 255
            imgval[:] = (imgval & 0xFFFFFF) | (alpha.astype(np.uint32) << 24)

    if options.spread_pix:
        img = datashader.transfer_functions.dynspread(img, options.spread_thr, max_px=options.spread_pix)
        log.info(f": spreading ({options.spread_thr} {options.spread_pix})")
    rgb = holoviews.RGB(holoviews.operation.datashader.shade.uint32_to_uint8_xr(img))

    log.debug('done')

    # Set plot limits based on data extent or user values for axis labels
    xmin, xmax = bounds[xaxis]
    ymin, ymax = bounds[yaxis]

    log.debug('rendering image')

    fig = pylab.figure(figsize=(figx, figy))
    ax = fig.add_subplot(111, facecolor=bgcol)
    for funcname, args, kwargs in extra_markup:
        getattr(ax, funcname)(*args, **kwargs)
    ax.imshow(X=rgb.data, extent=[xmin, xmax, ymin, ymax],
              aspect='auto', origin='lower', interpolation='nearest')

    ax.set_title("\n".join(textwrap.wrap(title, 90)), loc='center', fontdict=dict(fontsize=options.fontsize))
    ax.set_xlabel(xlabel, fontdict=dict(fontsize=options.fontsize))
    ax.set_ylabel(ylabel, fontdict=dict(fontsize=options.fontsize))

    # pad the limits by 1% of the data extent on each side
    dx, dy = xmax - xmin, ymax - ymin
    ax.set_xlim([xmin - dx / 100, xmax + dx / 100])
    ax.set_ylim([ymin - dy / 100, ymax + dy / 100])

    def decimate_list(x, maxel):
        """Helper function to reduce a list to < given max number of elements, dividing it by decimal factors of 2 and 5"""
        factors = 2, 5, 10
        base = divisor = 1
        while len(x) // divisor > maxel:
            for fac in factors:
                divisor = fac * base
                if len(x) // divisor <= maxel:
                    break
            base *= 10
        return x[::divisor]

    ax.tick_params(labelsize=options.fontsize * 0.66)

    # max # of tickmarks and labels to draw for discrete axes
    MAXLABELS = 64   # if we have up to this many labels, show them all
    MAXLABELS1 = 32  # if we have >MAXLABELS to show, then sparsify and get below this number
    MAXTICKS = 300   # if total number of points is within this range, draw them as minor tickmarks

    # do we have discrete labels to put on the axes?
    if xdatum.discretized_labels is not None:
        n = len(xdatum.discretized_labels)
        ticks_labels = list(enumerate(xdatum.discretized_labels))
        if n > MAXLABELS:
            ticks_labels = decimate_list(ticks_labels, MAXLABELS1)  # enforce max number of tick labels
        labels = [label for _, label in ticks_labels]
        longest = max((len(label) for label in xdatum.discretized_labels), default=0)
        rot = 90 if longest * n > 60 else 0
        ax.set_xticks([x[0] for x in ticks_labels])
        ax.set_xticklabels(labels, rotation=rot)
        if len(ticks_labels) < n and n <= MAXTICKS:
            ax.set_xticks(range(n), minor=True)

    if ydatum.discretized_labels is not None:
        n = len(ydatum.discretized_labels)
        ticks_labels = list(enumerate(ydatum.discretized_labels))
        if n > MAXLABELS:
            ticks_labels = decimate_list(ticks_labels, MAXLABELS1)  # enforce max number of tick labels
        labels = [label for _, label in ticks_labels]
        ax.set_yticks([y[0] for y in ticks_labels])
        ax.set_yticklabels(labels)
        if len(ticks_labels) < n and n <= MAXTICKS:
            ax.set_yticks(range(n), minor=True)

    # colorbar?
    if color_minmax:
        import matplotlib.colors
        # discrete axis: one colour per label, ticks centred on each colour
        if caxis is not None and cdatum.is_discrete:
            norm = matplotlib.colors.Normalize(-0.5, len(color_key) - 0.5)
            ticks = np.arange(len(color_key))
            colormap = matplotlib.colors.ListedColormap(color_key)
        # discretized axis
        else:
            norm = matplotlib.colors.Normalize(*color_minmax)
            colormap = matplotlib.colors.ListedColormap(color_key)
            # auto-mark colorbar, since it represents a continuous range of values
            ticks = None

        cb = fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=colormap), ax=ax, ticks=ticks)

        # adjust ticks for discrete axis
        if caxis is not None and cdatum.is_discrete:
            rot = 0
            # adjust fontsize for number of labels
            fs = max(options.fontsize * min(0.8, 20. / len(color_labels)), 6)
            fontdict = dict(fontsize=fs)
            if max(len(lbl) for lbl in color_labels) > 3 and len(color_labels) < 8:
                rot = 90
                fontdict['verticalalignment'] = 'center'
            cb.ax.set_yticklabels(color_labels, rotation=rot, fontdict=fontdict)
        else:
            cb.ax.tick_params(labelsize=options.fontsize * 0.8)

    fig.savefig(pngname, bbox_inches='tight')

    # close this specific figure (not just the "current" one, which may belong to another render)
    pylab.close(fig)

    return pngname
def __init__(self, column, function, corr, ms, minmax=None, ncol=None, label=None, subset=None):
    """Initialise a data axis. Normally created via the ``register()`` class method, not directly.

    Args:
        column: MS column to read. May be None/'' for mappers that need no column or imply a
            fixed one. May also be a simple two-column expression such as ``A-B`` (one of
            ``+ - * /``).
        function: name of the data mapper to apply (key into ``data_mappers``).
        corr: correlation to use; None or ``"all"`` means the axis is not tied to one
            correlation.
        ms: MS info object used to look up index ranges and labels.
        minmax: optional (min, max) clipping limits; None entries are auto-ranged.
        ncol: number of discrete levels (stored as ``nlevels``).
        label: unique axis label.
        subset: selection object (``corr``, ``field``, ``ant``, ``scan``, ``baseline`` subsets).

    Raises:
        ValueError: if ``column`` contradicts a column implied by the mapper.
    """
    self.name = ":".join([str(x) for x in (function, column, corr, minmax, ncol) if x is not None])
    self.ms = ms
    self.function = function  # function to apply to column (see list of DataMappers)
    self.corr = corr if corr != "all" else None
    self.nlevels = ncol
    self.minmax = tuple(minmax) if minmax is not None else (None, None)
    self._minmax_autorange = (self.minmax == (None, None))
    self.label = label
    self._corr_reduce = None
    self._is_discrete = None
    self.mapper = data_mappers[function]

    # if set, axis is discrete and labelled
    self.discretized_labels = None
    # for discrete axes: if a subset of N indices is explicitly selected for plotting, then this
    # is a list of the selected indices, of length N
    self.subset_indices = None
    # ...and this is a dask array that maps selected indices into bins 0...N-1, and all other values into bin N
    self.subset_remapper = None
    # ...and this is the maximum valid index in MS
    maxind = None

    # columns with labels?
    if function == 'CORR' or function == 'STOKES':
        corrset = set(subset.corr.names)
        corrset1 = corrset - set("IQUV")
        if corrset1 == corrset:
            name = "Correlation"
        elif not corrset1:
            name = "Stokes"
        else:
            name = "Correlation or Stokes"
        # When corr is fixed, we create a mapper that returns that known correlation as a constant.
        # When corr is not fixed (None or "all", both normalised to self.corr None), the mapper
        # iterates over correlations.
        fixed_corr = self.corr
        if fixed_corr is not None:
            self.mapper = DataMapper(name, "", column=False, axis=-1, mapper=lambda x: fixed_corr)
        self.subset_indices = subset.corr
        maxind = ms.all_corr.numbers[-1]
    elif column == "FIELD_ID":
        self.subset_indices = subset.field
        maxind = ms.field.numbers[-1]
    elif column == "ANTENNA1" or column == "ANTENNA2":
        self.subset_indices = subset.ant
        maxind = ms.antenna.numbers[-1]
    elif column == "SCAN_NUMBER":
        self.subset_indices = subset.scan
        maxind = ms.scan.numbers[-1]
    elif function == "BASELINE":
        self.subset_indices = subset.baseline
        maxind = ms.baseline.numbers[-1]
    elif function == "BASELINE_M":
        bl_subset = set(subset.baseline.numbers)  # active baselines
        numbers = [i for i in ms.baseline_m.numbers if i in bl_subset]
        names = [bl for i, bl in zip(ms.baseline_m.numbers, ms.baseline_m.names) if i in bl_subset]
        self.subset_indices = ms_info.NamedList("baseline_m", names, numbers)
        maxind = ms.baseline.numbers[-1]
    elif column == "FLAG" or column == "FLAG_ROW":
        self.discretized_labels = ["F", "T"]

    # make a remapper
    if self.subset_indices is not None:
        # is the mapping from indices to bins 1:1? (an empty subset never is, and must not be indexed)
        subind = np.array(self.subset_indices.numbers)
        identity = len(subind) > 0 and subind[0] == 0 and ((subind[1:] - subind[:-1]) == 1).all()
        # If mapping is not 1:1, or subset is short of full set, then we need a remapper.
        # Map indices in subset into their ordinal numbers in the subset (0...N-1), and all other indices to N
        if len(self.subset_indices) < maxind + 1 or not identity:
            remapper = np.full(maxind + 1, len(self.subset_indices))
            for i, index in enumerate(self.subset_indices.numbers):
                remapper[index] = i
            self.subset_remapper = da.array(remapper)
        self.discretized_labels = self.subset_indices.names
        self.subset_indices = self.subset_indices.numbers

    if self.discretized_labels:
        self._is_discrete = True

    # axis name
    self.fullname = self.mapper.fullname or column or ''
    # if labels were set up, adjust nlevels (but only down, never up)
    if self.discretized_labels is not None and self.nlevels is not None:
        self.nlevels = min(len(self.discretized_labels), self.nlevels)

    if self.function == "_":
        self.function = ""
    self.conjugate = self.mapper.conjugate
    self.timefreq_axis = self.mapper.axis

    # setup columns
    self._ufunc = None
    self.columns = ()

    # does the mapper have no column (i.e. frequency)?
    if self.mapper.column is False:
        log.info(f'axis: {function}, range {self.minmax}, discretization {self.nlevels}')
    # does the mapper have a fixed column? This better be consistent
    elif self.mapper.column is not None:
        # if a mapper (such as "uv") implies a fixed column name, make sure it's consistent with what the user said
        if column and self.mapper.column != column:
            raise ValueError(f"'{function}' not applicable with column {column}")
        # The implied column is what must be read. If the user gave no column, recording the
        # user's (None/empty) value here would request a nonexistent column from the MS.
        column = self.mapper.column
        self.columns = (column, )
        log.info(f'axis: {function}({column}), range {self.minmax}, discretization {self.nlevels}')
    # else arbitrary column
    else:
        log.info(f'axis: {function}({column}), corr {self.corr}, range {self.minmax}, discretization {self.nlevels}')
        # check for column arithmetic
        match = re.fullmatch(r"(\w+)([*/+-])(\w+)", column)
        if match:
            self.columns = (match.group(1), match.group(3))
            # look up dask ufunc corresponding to arithmetic op
            self._ufunc = {'+': da.add, '*': da.multiply, '-': da.subtract, '/': da.divide}[match.group(2)]
        else:
            self.columns = (column, )