Example #1
    def register(cls,
                 function,
                 column,
                 corr,
                 ms,
                 minmax=None,
                 ncol=None,
                 subset=None,
                 minmax_cache=None):
        """
        Registers a data axis, which ultimately ends up as a column in the assembled dataframe.
        For multiple plots, we want to reuse the same information (assuming the same
        clipping limits, etc.) so we have a dictionary of axis definitions here.

        column selects a column to operate on.
        function selects a mapping (see datamappers below).
        corr selects a correlation (or a Stokes product such as I, Q, ...).
        minmax sets the axis clipping limits.
        ncol discretizes the axis into N colours between min and max.
        minmax_cache provides a dict of cached min/max values, which will be looked up via the label
                     if minmax is not explicitly set.
        """
        # form up label
        label = "{}_{}_{}".format(col_to_label(column or ''), function, corr)
        minmax = tuple(minmax) if minmax is not None else (None, None)
        key = label, minmax, ncol
        # see if this axis definition already exists, else create new one
        if key in cls.all_axes:
            return cls.all_axes[key]
        else:
            # see if minmax should be loaded
            if minmax == (None, None) and minmax_cache and label in minmax_cache:
                log.info(f"loading {label} min/max from cache")
                minmax = minmax_cache[label]

            label0, i = label, 0
            while label in cls.all_labels:
                i += 1
                label = f"{label}_{i}"
            cls.all_labels.add(label)
            axis = cls.all_axes[key] = DataAxis(column,
                                                function,
                                                corr,
                                                ms,
                                                minmax,
                                                ncol,
                                                label,
                                                subset=subset)
            return axis
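A minimal usage sketch (hypothetical values, following the signature above): registering the same axis specification twice returns the cached DataAxis rather than creating a duplicate, because the (label, minmax, ncol) key already exists.

# assuming an MSInfo object `ms` and a selection object `subset`, as in the callers shown below
xdatum = DataAxis.register("amp", "DATA", corr=0, ms=ms, minmax=(0, 10), subset=subset)
ydatum = DataAxis.register("amp", "DATA", corr=0, ms=ms, minmax=(0, 10), subset=subset)
assert xdatum is ydatum  # identical (label, minmax, ncol) key, so the existing axis is reused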
Example #2
def render_single_plot(df, subset, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel):
    """Renders a single plot. Make this a function since we might call it in parallel"""
    log.info(f": rendering {pngname}")
    normalize = options.norm
    if normalize == "auto":
        normalize = "log" if cdatum is not None else ("eq_hist" if adatum is None else 'linear')
    if options.profile:
        context = dask.diagnostics.ResourceProfiler
    else:
        context = nullcontext
    with context() as profiler:
        result = data_plots.create_plot(df, subset, xdatum, ydatum, adatum, ared, cdatum,
                                        cmap=cmap, bmap=bmap, dmap=dmap, normalize=normalize,
                                        min_alpha=options.min_alpha,
                                        saturate_alpha=options.saturate_alpha,
                                        saturate_percentile=options.saturate_perc,
                                        xlabel=xlabel, ylabel=ylabel, title=title, pngname=pngname,
                                        extra_markup=extra_markup,
                                        minmax_cache=minmax_cache,
                                        options=options)
    if result:
        log.info(f'                 : wrote {pngname}')
        if profiler is not None:
            profile_file = os.path.splitext(pngname)[0] + ".prof.html"
            dask.diagnostics.visualize(profiler, file_path=profile_file, show=False, save=True)
            log.info(f'                 : wrote profiler info to {profile_file}')
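When profiling is disabled, the "with context() as profiler" line above binds profiler to None, so the profiler branch is skipped. A small sketch of that behaviour (standard library only):

from contextlib import nullcontext

with nullcontext() as profiler:  # nullcontext() yields None by default
    pass
assert profiler is None  # hence the "if profiler is not None" branch is not taken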
Example #3
def get_colormap(cmap_name):
    cmap = getattr(colorcet, cmap_name, None)
    if cmap:
        log.info(f"using colourmap colorcet.{cmap_name}")
        return cmap
    cmap = getattr(cmasher, cmap_name, None)
    if cmap:
        log.info(f"using colourmap cmasher.{cmap_name}")
    else:
        cmap = getattr(matplotlib.cm, cmap_name, None)
        if cmap is None:
            raise ValueError(f"unknown colourmap {cmap_name}")
        log.info(f"using colourmap matplotplib.cm.{cmap_name}")
    return [ f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}" for r,g,b in cmap.colors ]
Example #4
def get_plot_data(msinfo, group_cols, mytaql, chan_freqs,
                  chanslice, subset,
                  noflags, noconj,
                  iter_field, iter_spw, iter_scan,
                  join_corrs=False,
                  row_chunk_size=100000):

    ms_cols = {'ANTENNA1', 'ANTENNA2'}
    if not noflags:
        ms_cols.update({'FLAG', 'FLAG_ROW'})
    # get visibility columns
    for axis in DataAxis.all_axes.values():
        ms_cols.update(axis.columns)

    # get MS data
    msdata = daskms.xds_from_ms(msinfo.msname, columns=list(ms_cols), group_cols=group_cols, taql_where=mytaql,
                                chunks=dict(row=row_chunk_size))

    log.info(f': Indexing MS and building dataframes (chunk size is {row_chunk_size})')

    np = 0  # number of points to plot

    # output dataframes, indexed by (field, spw, scan, antenna, correlation)
    # If any of these axes is not being iterated over, then the index is None
    output_dataframes = OrderedDict()

    # iterate over groups
    for group in msdata:
        ddid     =  group.DATA_DESC_ID  # always present
        fld      =  group.FIELD_ID # always present
        if fld not in subset.field or ddid not in subset.spw:
            log.debug(f"field {fld} ddid {ddid} not in selection, skipping")
            continue
        scan    = getattr(group, 'SCAN_NUMBER', None)  # will be present if iterating over scans

        # TODO: antenna iteration. None forces no iteration, for now
        antenna = None

        # always read flags -- easier that way
        flag = group.FLAG if not noflags else None
        flag_row = group.FLAG_ROW if not noflags else None


        baselines = group.ANTENNA1*len(msinfo.antenna) + group.ANTENNA2

        freqs = chan_freqs[ddid]
        chans = xarray.DataArray(range(len(freqs)), dims=("chan",))
        wavel = freq_to_wavel(freqs)
        extras = dict(chans=chans, freqs=freqs, wavel=wavel, rows=group.row, baselines=baselines)

        nchan = len(group.chan)
        if flag is not None:
            flag = flag[dict(chan=chanslice)]
            nchan = flag.shape[1]
        shape = (len(group.row), nchan)

        datums = OrderedDict()

        for corr in subset.corr.numbers:
            # make dictionary of extra values for DataMappers
            extras['corr'] = corr
            # loop over datums to be computed
            for axis in DataAxis.all_axes.values():
                value = datums[axis.label][-1] if axis.label in datums else None
                # a datum was already computed?
                if value is not None:
                    # if not joining correlations, then that's the only one we'll need, so continue
                    if not join_corrs:
                        continue
                    # joining correlations, and datum has a correlation dependence: compute another one
                    if axis.corr is None:
                        value = None
                if value is None:
                    value = axis.get_value(group, corr, extras, flag=flag, flag_row=flag_row, chanslice=chanslice)
                    # reshape values of shape NTIME to (NTIME,1) and NFREQ to (1,NFREQ), and scalar to (NTIME,1)
                    if value.ndim == 1:
                        timefreq_axis = axis.mapper.axis or 0
                        assert value.shape[0] == shape[timefreq_axis], \
                               f"{axis.mapper.fullname}: size {value.shape[0]}, expected {shape[timefreq_axis]}"
                        shape1 = [1,1]
                        shape1[timefreq_axis] = value.shape[0]
                        value = value.reshape(shape1)
                        if timefreq_axis > 0:
                            value = da.broadcast_to(value, shape)
                        log.debug(f"axis {axis.mapper.fullname} has shape {value.shape}")
                    # else 2D value better match expected shape
                    else:
                        assert value.shape == shape, f"{axis.mapper.fullname}: shape {value.shape}, expected {shape}"
                datums.setdefault(axis.label, []).append(value)

        # if joining correlations, stick all elements together. Otherwise, we'd better have one per label
        if join_corrs:
            datums = OrderedDict({label: da.concatenate(arrs) for label, arrs in datums.items()})
        else:
            assert all([len(arrs) == 1 for arrs in datums.values()])
            datums = OrderedDict({label: arrs[0] for label, arrs in datums.items()})

        # broadcast to same shape, and unravel all datums
        datums = OrderedDict({ key: arr.ravel() for key, arr in zip(datums.keys(),
                                                                    da.broadcast_arrays(*datums.values()))})

        # if any axis needs to be conjugated, double up all of them
        if not noconj and any([axis.conjugate for axis in DataAxis.all_axes.values()]):
            for axis in DataAxis.all_axes.values():
                if axis.conjugate:
                    datums[axis.label] = da.concatenate([datums[axis.label], -datums[axis.label]])
                else:
                    datums[axis.label] = da.concatenate([datums[axis.label], datums[axis.label]])

        labels, values = list(datums.keys()), list(datums.values())
        np += values[0].size

        # now stack them all into a big dataframe
        rectype = [(axis.label, numpy.int32 if axis.nlevels else numpy.float32) for axis in DataAxis.all_axes.values()]
        recarr = da.empty_like(values[0], dtype=rectype)
        ddf = dask_df.from_array(recarr)
        for label, value in zip(labels, values):
            ddf[label] = value

        # now, are we iterating or concatenating? Make frame key accordingly
        dataframe_key = (fld if iter_field else None,
                         ddid if iter_spw else None,
                         scan if iter_scan else None,
                         antenna)

        # do we already have a frame for this key
        ddf0 = output_dataframes.get(dataframe_key)

        if ddf0 is None:
            log.debug(f"first frame for {dataframe_key}")
            output_dataframes[dataframe_key] = ddf
        else:
            log.debug(f"appending to frame for {dataframe_key}")
            output_dataframes[dataframe_key] = ddf0.append(ddf)

    # convert discrete axes into categoricals
    if data_mappers.USE_COUNT_CAT:
        categorical_axes = [axis.label for axis in DataAxis.all_axes.values() if axis.nlevels]
        if categorical_axes:
            log.info(": counting colours")
            for key, ddf in list(output_dataframes.items()):
                output_dataframes[key] = ddf.categorize(categorical_axes)

    log.info(": complete")
    return output_dataframes, np
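The 1D-value handling above reshapes per-row values to (NTIME, 1) and per-channel values to (1, NFREQ), then broadcasts them up to the full 2D shape. A small sketch with toy sizes (variable names are hypothetical):

import dask.array as da

shape = (5, 3)                         # (NTIME, NFREQ)
per_chan = da.arange(3)                # a per-channel quantity, e.g. channel number
value = per_chan.reshape((1, 3))       # time/freq axis is 1, so shape becomes (1, NFREQ)
value = da.broadcast_to(value, shape)  # expand to (NTIME, NFREQ)
assert value.shape == shape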
Example #5
def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, normalize,
                xlabel, ylabel, title, pngname,
                options=None):

    figx = options.xcanvas / 60
    figy = options.ycanvas / 60
    bgcol = "#" + options.bgcol.lstrip("#")

    xaxis = xdatum.label
    yaxis = ydatum.label
    aaxis = adatum and adatum.label
    caxis = cdatum and cdatum.label
    color_key = ncolors = color_mapping = color_labels = agg_alpha = raster_alpha = None

    xmin, xmax = xdatum.minmax
    ymin, ymax = ydatum.minmax

    canvas = datashader.Canvas(options.xcanvas, options.ycanvas,
                               x_range=[xmin, xmax] if xmin is not None else None,
                               y_range=[ymin, ymax] if ymin is not None else None)

    if aaxis is not None:
        agg_alpha = getattr(datashader.reductions, ared, None)
        if agg_alpha is None:
            raise ValueError(f"unknown alpha reduction function {ared}")
        agg_alpha = agg_alpha(aaxis)
    ared = ared or 'count'

    if cdatum is not None:
        if agg_alpha is not None and not USE_REDUCE_BY:
            log.debug(f'rasterizing alpha channel using {ared}(aaxis)')
            raster_alpha = canvas.points(ddf, xaxis, yaxis, agg=agg_alpha)

        if data_mappers.USE_COUNT_CAT:
            color_bins = [int(x) for x in getattr(ddf.dtypes, caxis).categories]
            log.debug(f'colourizing with count_cat, {len(color_bins)} bins')
            if USE_REDUCE_BY and agg_alpha:
                agg = datashader.by(caxis, agg_alpha)
            else:
                agg = datashader.count_cat(caxis)
        else:
            color_bins = list(range(cdatum.nlevels))
            log.debug(f'colourizing with count_integer, {len(color_bins)} bins')
            if USE_REDUCE_BY and agg_alpha:
                agg = by_integers(caxis, agg_alpha, cdatum.nlevels)
            else:
                agg = count_integers(caxis, cdatum.nlevels)


        raster = canvas.points(ddf, xaxis, yaxis, agg=agg)
        non_empty = numpy.array(raster.any(axis=(0, 1)))
        if not non_empty.any():
            log.info(": no valid data in plot. Check your flags and/or plot limits.")
            return None
        # true if axis is continuous discretized
        if cdatum.discretized_delta is not None:
            # color labels are bin centres
            bin_centers = [cdatum.discretized_bin_centers[i] for i in color_bins]
            # map to colors pulled from 256 color map
            color_key = [bmap[(i*256)//cdatum.nlevels] for i in color_bins]
            color_labels = list(map(str, bin_centers))
            log.info(f": shading using {len(color_bins)} colors (bin centres are {' '.join(color_labels)})")
        # else a discrete axis
        else:
            # discard empty bins
            non_empty = numpy.where(non_empty)[0]
            raster = raster[..., non_empty]
            # just use bin numbers to look up a color directly
            color_bins = [color_bins[i] for i in non_empty]
            color_key = [dmap[bin] for bin in color_bins]
            # the numbers may be out of order -- reorder for color bar purposes
            bin_color = sorted(zip(color_bins, color_key))
            if cdatum.discretized_labels and len(cdatum.discretized_labels) <= cdatum.nlevels:
                color_labels = [cdatum.discretized_labels[bin] for bin, _ in bin_color]
            else:
                color_labels = [str(bin) for bin, _ in bin_color]
            color_mapping = [col for _, col in bin_color]
            log.info(f": rendering using {len(color_bins)} colors (values {' '.join(color_labels)})")
        if raster_alpha is not None:
            amin, amax = numpy.nanmin(raster_alpha), numpy.nanmax(raster_alpha)
            raster = raster*(raster_alpha-amin)/(amax-amin)
            log.info(f": adjusting alpha (alpha raster was {amin} to {amax})")
        img = datashader.transfer_functions.shade(raster, color_key=color_key, how=normalize)
    else:
        log.debug(f'rasterizing using {ared}')
        raster = canvas.points(ddf, xaxis, yaxis, agg=agg_alpha)
        if not raster.data.any():
            log.info(": no valid data in plot. Check your flags and/or plot limits.")
            return None
        log.debug('shading')
        img = datashader.transfer_functions.shade(raster, cmap=cmap, how=normalize)

    if options.spread_pix:
        img = datashader.transfer_functions.dynspread(img, options.spread_thr, max_px=options.spread_pix)
        log.info(f": spreading ({options.spread_thr} {options.spread_pix})")
    rgb = holoviews.RGB(holoviews.operation.datashader.shade.uint32_to_uint8_xr(img))

    log.debug('done')

    # Set plot limits based on data extent or user values for axis labels

    data_xmin = numpy.min(raster.coords[xaxis].values)
    data_xmax = numpy.max(raster.coords[xaxis].values)
    data_ymin = numpy.min(raster.coords[yaxis].values)
    data_ymax = numpy.max(raster.coords[yaxis].values)

    xmin = data_xmin if xmin is None else xdatum.minmax[0]
    xmax = data_xmax if xmax is None else xdatum.minmax[1]
    ymin = data_ymin if ymin is None else ydatum.minmax[0]
    ymax = data_ymax if ymax is None else ydatum.minmax[1]

    log.debug('rendering image')

    def match(artist):
        return artist.__module__ == 'matplotlib.text'

    fig = pylab.figure(figsize=(figx, figy))
    ax = fig.add_subplot(111, facecolor=bgcol)
    ax.imshow(X=rgb.data, extent=[data_xmin, data_xmax, data_ymin, data_ymax],
              aspect='auto', origin='lower')
    ax.set_title("\n".join(textwrap.wrap(title, 90)), loc='left')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # ax.plot(xmin,ymin,'.',alpha=0.0)
    # ax.plot(xmax,ymax,'.',alpha=0.0)

    dx, dy = xmax - xmin, ymax - ymin
    ax.set_xlim([xmin - dx/100, xmax + dx/100])
    ax.set_ylim([ymin - dy/100, ymax + dy/100])

    # set fontsize on everything rendered so far
    for textobj in fig.findobj(match=match):
        textobj.set_fontsize(options.fontsize)

    # colorbar?
    if color_key:
        import matplotlib.colors
        # discrete axis
        if color_mapping is not None:
            norm = matplotlib.colors.Normalize(-0.5, len(color_bins)-0.5)
            ticks = numpy.arange(len(color_bins))
            colormap = matplotlib.colors.ListedColormap(color_mapping)
        # discretized axis
        else:
            norm = matplotlib.colors.Normalize(cdatum.minmax[0], cdatum.minmax[1])
            colormap = matplotlib.colors.ListedColormap(color_key)
            # auto-mark colorbar, since it represents a continuous range of values
            ticks = None

        cb = fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=colormap), ax=ax, ticks=ticks)

        # adjust ticks for discrete axis
        if color_mapping is not None:
            rot = 0
            # adjust fontsize for number of labels
            fs = max(options.fontsize*min(1, 32./len(color_labels)), 6)
            fontdict = dict(fontsize=fs)
            if max([len(lbl) for lbl in color_labels]) > 3 and len(color_labels) < 8:
                rot = 90
                fontdict['verticalalignment'] = 'center'
            cb.ax.set_yticklabels(color_labels, rotation=rot, fontdict=fontdict)

    fig.savefig(pngname, bbox_inches='tight')

    pylab.close()

    return pngname
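In the non-reduce-by path above, the colour raster is modulated by the min/max-normalized alpha raster (raster*(raster_alpha-amin)/(amax-amin)). In outline, with made-up numbers:

import numpy

raster_alpha = numpy.array([[1.0, 3.0], [2.0, 5.0]])
amin, amax = numpy.nanmin(raster_alpha), numpy.nanmax(raster_alpha)
weights = (raster_alpha - amin) / (amax - amin)  # normalized to the 0..1 range
# the categorical raster is then scaled by these weights before shading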
Example #6
def main(argv):


    clock_start = time.time()

    # default # of CPUs


    # ---------------------------------------------------------------------------------------------------------------------------------------------


    parser = argparse.ArgumentParser(description='Rapid Measurement Set plotting with dask-ms and datashader. Version {0:s}'.format(__version__))


    parser.add_argument('ms', 
                      help='Measurement set')
    parser.add_argument("-v", "--version", action='version',
                      version='{:s} version {:s}'.format(parser.prog, __version__))

    group_opts = parser.add_argument_group('Plot types and data sources')

    group_opts.add_argument('-x', '--xaxis', dest='xaxis', action="append",
                      help="""X axis of plot, e.g. "amp:CORRECTED_DATA" This recognizes all column names (also CHAN, FREQ, 
                      CORR, ROW, WAVEL, U, V, W, UV), and, for complex columns, keywords such as 'amp', 'phase', 'real', 'imag'. You can also 
                      specify correlations, e.g. 'DATA:phase:XX', and do two-column arithmetic with "+-*/", e.g. 
                      'DATA-MODEL_DATA:amp'. Correlations may be specified by label, number, or as a Stokes parameter.
                      The order of specifiers does not matter.
                      """)

    group_opts.add_argument('-y', '--yaxis', dest='yaxis', action="append",
                      help="""Y axis to plot. Must be given the same number of times as --xaxis. Note that X/Y can
                      employ different columns and correlations.""")

    group_opts.add_argument('-a', '--aaxis', action="append",
                      help="""Intensity axis. Can be none, or given once, or given the same number of times as --xaxis.
                      If none, plot intensity (a.k.a. alpha channel) is proportional to density of points. Otherwise,
                      a reduction function (see --ared below) is applied to the given values, and the result is used
                      to determine intensity.
                      """)

    group_opts.add_argument('--ared', action="append",
                      help="""Alpha axis reduction function. Recognized reductions are count, any, sum, min, max,
                      mean, std, first, last, mode. Default is mean.""")

    group_opts.add_argument('-c', '--colour-by', action="append",
                      help="""Colour axis. Can be none, or given once, or given the same number of times as --xaxis.
                      All columns and variations listed under --xaxis are available for colouring by.""")

    group_opts.add_argument('-C', '--col', metavar="COLUMN", dest='col', action="append", default=[],
                      help="""Name of visibility column (default is DATA), if needed. This is used if
                      the axis specifications do not explicitly include a column. For multiple plots,
                      this can be given multiple times, or as a comma-separated list. Two-column arithmetic is recognized.
                      """)

    group_opts.add_argument('--noflags',
                      help='Enable to ignore flags. Default is to omit flagged data.', action='store_true')
    group_opts.add_argument('--noconj',
                      help='Do not show conjugate points in u,v plots (default = plot conjugates).', action='store_true')

    group_opts = parser.add_argument_group('Plot axes setup')

    group_opts.add_argument('--xmin', action='append',
                      help="""Minimum x-axis value (default = data min). For multiple plots, you can give this 
                      multiple times, or use a comma-separated list, but note that the clipping is the same per axis 
                      across all plots, so only the last applicable setting will be used. The list may include empty
                      elements (or 'None') to not apply a clip.""")
    group_opts.add_argument('--xmax', action='append',
                      help='Maximum x-axis value (default = data max).')
    group_opts.add_argument('--ymin', action='append',
                      help='Minimum y-axis value (default = data min).')
    group_opts.add_argument('--ymax', action='append',
                      help='Maximum y-axis value (default = data max).')
    group_opts.add_argument('--cmin', action='append',
                      help='Minimum colouring value. Must be supplied for every non-discrete axis to be coloured by.')
    group_opts.add_argument('--cmax', action='append',
                      help='Maximum colouring value. Must be supplied for every non-discrete axis to be coloured by.')
    group_opts.add_argument('--cnum', action='append',
                      help=f'Number of steps used to discretize a continuous axis. Default is {DEFAULT_CNUM}.')

    group_opts = parser.add_argument_group('Options for multiple plots or combined plots')

    group_opts.add_argument('--iter-field', action="store_true",
                      help='Separate plots per field (default is to combine in one plot)')
    group_opts.add_argument('--iter-antenna', action="store_true",
                      help='Separate plots per antenna (default is to combine in one plot)')
    group_opts.add_argument('--iter-spw', action="store_true",
                      help='Separate plots per spw (default is to combine in one plot)')
    group_opts.add_argument('--iter-scan', action="store_true",
                      help='Separate plots per scan (default is to combine in one plot)')
    group_opts.add_argument('--iter-corr', action="store_true",
                      help='Separate plots per correlation or Stokes (default is to combine in one plot)')

    group_opts = parser.add_argument_group('Data subset selection')

    group_opts.add_argument('--ant', default='all',
                      help='Antennas to plot (comma-separated list of names, default = all)')
    group_opts.add_argument('--ant-num',
                      help='Antennas to plot (comma-separated list of numbers, or a [start]:[stop][:step] slice, overrides --ant)')
    group_opts.add_argument('--baseline', default='all',
                      help="Baselines to plot, as 'ant1-ant2' (comma-separated list, default = all)")
    group_opts.add_argument('--spw', default='all',
                      help='Spectral windows (DDIDs) to plot (comma-separated list, default = all)')
    group_opts.add_argument('--field', default='all',
                      help='Field ID(s) to plot (comma-separated list, default = all)')
    group_opts.add_argument('--scan', default='all',
                      help='Scans to plot (comma-separated list, default = all)')
    group_opts.add_argument('--corr',  default='all',
                      help='Correlations or Stokes to plot, use indices or labels (comma-separated list, default = all)')
    group_opts.add_argument('--chan',
                      help='Channel slice, as [start]:[stop][:step], default is to plot all channels')

    group_opts = parser.add_argument_group('Rendering settings')

    group_opts.add_argument('-X', '--xcanvas', type=int,
                      help='Canvas x-size in pixels (default = %(default)s)', default=1280)
    group_opts.add_argument('-Y', '--ycanvas', type=int,
                      help='Canvas y-size in pixels (default = %(default)s)', default=900)
    group_opts.add_argument('--norm', choices=['auto', 'eq_hist', 'cbrt', 'log', 'linear'], default='auto',
                      help="Pixel scale normalization (default is 'log' when colouring, and 'eq_hist' when not)")
    group_opts.add_argument('--cmap', default='bkr',
                      help="""Colorcet map used without --colour-by  (default = %(default)s), see
                      https://colorcet.holoviz.org""")
    group_opts.add_argument('--bmap', default='bkr',
                      help='Colorcet map used when colouring by a continuous axis (default = %(default)s)')
    group_opts.add_argument('--dmap', default='glasbey_dark',
                      help='Colorcet map used when colouring by a discrete axis (default = %(default)s)')
    group_opts.add_argument('--spread-pix', type=int, default=0, metavar="PIX",
                      help="""Dynamically spread rendered pixels to this size""")
    group_opts.add_argument('--spread-thr', type=float, default=0.5, metavar="THR",
                      help="""Threshold parameter for spreading (0 to 1, default %(default)s)""")
    group_opts.add_argument('--bgcol', dest='bgcol',
                      help='RGB hex code for background colour (default = FFFFFF)', default='FFFFFF')
    group_opts.add_argument('--fontsize', dest='fontsize',
                      help='Font size for all text elements (default = 20)', default=20)

    group_opts = parser.add_argument_group('Output settings')

    # can also use "plot-{msbase}-{column}-{corr}-{xfullname}-vs-{yfullname}", let's expand on this later
    group_opts.add_argument('--dir',
                      help='Send all plots to this output directory')
    group_opts.add_argument('-s', '--suffix', help="suffix to be included in filenames, can include {options}")
    group_opts.add_argument('--png', dest='pngname',
                             default="plot-{ms}{_field}{_Spw}{_Scan}{_Ant}-{label}{_alphalabel}{_colorlabel}{_suffix}.png",
                      help='Template for output png files, default "%(default)s"')
    group_opts.add_argument('--title',
                             default="{ms}{_field}{_Spw}{_Scan}{_Ant}{_title}{_Alphatitle}{_Colortitle}",
                      help='Template for plot titles, default "%(default)s"')
    group_opts.add_argument('--xlabel',
                             default="{xname}{_xunit}",
                      help='Template for X axis labels, default "%(default)s"')
    group_opts.add_argument('--ylabel',
                             default="{yname}{_yunit}",
                             help='Template for Y axis labels, default "%(default)s"')

    group_opts = parser.add_argument_group('Performance & tweaking')

    group_opts.add_argument("-d", "--debug", action='store_true',
                            help="Enable debugging output")
    group_opts.add_argument('-z', '--row-chunk-size', type=int, metavar="NROWS", default=100000,
                           help="""Row chunk size for dask-ms. Larger chunks may or may not be faster, but will
                            certainly use more RAM.""")
    group_opts.add_argument('-j', '--num-parallel', type=int, metavar="N", default=1,
                             help=f"""Run up to N renderers in parallel. Default is serial. Use -j0 to 
                             auto-set this to half the available cores ({DEFAULT_NUM_RENDERS} on this system).
                             This is not necessarily faster, as they might all end up contending for disk I/O. 
                             This might also work against dask-ms's own intrinsic parallelism. 
                             You have been advised.""")
    group_opts.add_argument("--profile", action='store_true', help="Enable dask profiling output")


    # various hidden performance-testing options
    data_mappers.add_options(group_opts)
    data_plots.add_options(group_opts)

    options = parser.parse_args(argv)

    cmap = getattr(colorcet, options.cmap, None)
    if cmap is None:
        parser.error(f"unknown --cmap {options.cmap}")
    bmap = getattr(colorcet, options.bmap, None)
    if bmap is None:
        parser.error(f"unknown --bmap {options.bmap}")
    dmap = getattr(colorcet, options.dmap, None)
    if dmap is None:
        parser.error(f"unknown --dmap {options.dmap}")

    options.ms = options.ms.rstrip('/')

    if options.debug:
        shade_ms.log_console_handler.setLevel(logging.DEBUG)

    # pass options to shade_ms
    data_mappers.set_options(options)
    data_plots.set_options(options)

    # figure out list of plots to make

    if not options.xaxis:
        xaxes = ['TIME'] # Default xaxis if none is specified
    else:
        xaxes = list(itertools.chain(*[opt.split(",") for opt in options.xaxis]))
    if not options.yaxis:
        yaxes = ['DATA:amp'] # Default yaxis if none is specified
    else:
        yaxes = list(itertools.chain(*[opt.split(",") for opt in options.yaxis]))

    if len(xaxes) != len(yaxes):
        parser.error("--xaxis and --yaxis must be given the same number of times")

    def get_conformal_list(name, force_type=None, default=None):
        """
        For all other settings, returns list same length as xaxes, or throws error if no conformance.
        Can also impose a type such as float (returning None for an empty string)
        """
        optlist = getattr(options, name, None)
        if not optlist:
            return [default]*len(xaxes)
        # stick all lists together
        elems = list(itertools.chain(*[opt.split(",") for opt in optlist]))
        if len(elems) > 1 and len(elems) != len(xaxes):
            parser.error(f"--{name} must be given the same number of times as --xaxis, or else just once")
        # convert type
        if force_type:
            elems = [force_type(x) if x and x.lower() != "none" else None for x in elems]
        if len(elems) != len(xaxes):
            elems = [elems[0]]*len(xaxes)
        return elems
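    # For example (hypothetical values): with three plots (len(xaxes) == 3),
    # --xmin "0,,None" makes get_conformal_list('xmin', float) return [0.0, None, None],
    # while a single --xmin 0 is broadcast to [0.0, 0.0, 0.0].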

    # get list of columns and plot limits of the same length
    if not options.col:
        options.col = ["DATA"]
    columns = get_conformal_list('col')
    xmins = get_conformal_list('xmin', float)
    xmaxs = get_conformal_list('xmax', float)
    ymins = get_conformal_list('ymin', float)
    ymaxs = get_conformal_list('ymax', float)
    aaxes = get_conformal_list('aaxis')
    areds = get_conformal_list('ared', str, 'mean')
    caxes = get_conformal_list('colour_by')
    cmins = get_conformal_list('cmin', float)
    cmaxs = get_conformal_list('cmax', float)
    cnums = get_conformal_list('cnum', int, default=DEFAULT_CNUM)

    # check min/max
    if any([(a is None)^(b is None) for a, b in zip(xmins, xmaxs)]):
        parser.error("--xmin/--xmax must be either both set, or neither")
    if any([(a is None)^(b is None) for a, b in zip(ymins, ymaxs)]):
        parser.error("--ymin/--ymax must be either both set, or neither")
    if any([(a is None)^(b is None) for a, b in zip(cmins, cmaxs)]):
        parser.error("--cmin/--cmax must be either both set, or neither")

    # check chan slice
    def parse_slice_spec(spec, name):
        if spec:
            try:
                spec_elems = [int(x) if x else None for x in spec.split(":", 2)]
            except ValueError:
                parser.error(f"invalid selection --{name} {spec}")
            return slice(*spec_elems), spec_elems
        else:
            return slice(None), []
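    # e.g. parse_slice_spec("0:10:2", "chan") -> (slice(0, 10, 2), [0, 10, 2]),
    # while an empty spec returns (slice(None), []) to select all channels.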

    chanslice, chanslice_spec = parse_slice_spec(options.chan, name="chan")

    log.info(" ".join(sys.argv))

    blank()

    ms = MSInfo(options.ms, log=log)

    blank()
    log.info(": Data selected for plotting:")

    group_cols = ['FIELD_ID', 'DATA_DESC_ID']

    mytaql = []

    class Subset(object):
        pass
    subset = Subset()

    if options.ant != 'all' or options.ant_num:
        if options.ant_num:
            ant_subset = set()
            for spec in options.ant_num.split(","):
                if re.fullmatch(r"\d+", spec):
                    ant_subset.add(int(spec))
                else:
                    ant_subset.update(ms.all_antenna.numbers[parse_slice_spec(spec, "ant-num")[0]])
            subset.ant = ms.antenna.get_subset(sorted(ant_subset))
        else:
            subset.ant = ms.antenna.get_subset(options.ant)
        log.info(f"Antenna name(s)  : {' '.join(subset.ant.names)}")
        mytaql.append("||".join([f'ANTENNA1=={ant}||ANTENNA2=={ant}' for ant in subset.ant.numbers]))
    else:
        subset.ant = ms.antenna
        log.info('Antenna(s)       : all')

    if options.iter_antenna:
        raise NotImplementedError("iteration over antennas not currently supported")

    if options.baseline != 'all':
        subset.baseline = OrderedDict()
        for blspec in options.baseline.split(","):
            match = re.fullmatch(r"(\w+)-(\w+)", blspec)
            ant1 = match and ms.antenna[match.group(1)]
            ant2 = match and ms.antenna[match.group(2)]
            if ant1 is None or ant2 is None:
                raise ValueError("invalid baseline '{blspec}'")
            subset.baseline[blspec] = (ant1, ant2)
        # group_cols.append('ANTENNA1')
        log.info(f"Baseline(s)      : {' '.join(subset.baseline.keys())}")
        mytaql.append("||".join([f'(ANTENNA1=={ant1}&&ANTENNA2=={ant2})||(ANTENNA1=={ant2}&&ANTENNA2=={ant1})'
                                 for ant1, ant2 in subset.baseline.values()]))
    else:
        log.info('Baseline(s)      : all')

    if options.field != 'all':
        subset.field = ms.field.get_subset(options.field)
        log.info(f"Field(s)         : {' '.join(subset.field.names)}")
        # here for now, workaround for https://github.com/ska-sa/dask-ms/issues/100, should be inside if clause
        mytaql.append("||".join([f'FIELD_ID=={fld}' for fld in subset.field.numbers]))
    else:
        subset.field = ms.field
        log.info('Field(s)         : all')

    if options.spw != 'all':
        subset.spw = ms.spw.get_subset(options.spw)
        log.info(f"SPW(s)           : {' '.join(subset.spw.names)}")
        mytaql.append("||".join([f'DATA_DESC_ID=={ddid}' for ddid in subset.spw.numbers]))
    else:
        subset.spw = ms.spw
        log.info(f'SPW(s)           : all')

    if options.scan != 'all':
        subset.scan = ms.scan.get_subset(options.scan)
        log.info(f"Scan(s)          : {' '.join(subset.scan.names)}")
        mytaql.append("||".join([f'SCAN_NUMBER=={n}' for n in subset.scan.numbers]))
    else:
        subset.scan = ms.scan
        log.info('Scan(s)          : all')
    if options.iter_scan:
        group_cols.append('SCAN_NUMBER')

    if chanslice == slice(None):
        log.info('Channels         : all')
    else:
        log.info(f"Channels         : {':'.join(str(x) if x is not None else '' for x in chanslice_spec)}")

    mytaql = ' && '.join([f"({t})" for t in mytaql]) if mytaql else ''

    options.corr = options.corr.upper()
    if options.corr == "ALL":
        subset.corr = ms.corr
    else:
        # recognize shortcut when it's just Stokes indices, convert to list
        if re.fullmatch(r"[IQUV]+", options.corr):
            options.corr = ",".join(options.corr)
        subset.corr = ms.all_corr.get_subset(options.corr)
    log.info(f"Corr/Stokes      : {' '.join(subset.corr.names)}")

    blank()

    # figure out list of plots to make
    all_plots = []

    # This will be True if any of the specified axes change with correlation
    have_corr_dependence = False

    # now go create definitions
    for xaxis, yaxis, default_column, caxis, aaxis, ared, xmin, xmax, ymin, ymax, cmin, cmax, cnum in \
            zip(xaxes, yaxes, columns, caxes, aaxes, areds, xmins, xmaxs, ymins, ymaxs, cmins, cmaxs, cnums):
        # get axis specs
        xspecs = DataAxis.parse_datum_spec(xaxis, default_column, ms=ms)
        yspecs = DataAxis.parse_datum_spec(yaxis, default_column, ms=ms)
        aspecs = DataAxis.parse_datum_spec(aaxis, default_column, ms=ms) if aaxis else [None] * 4
        cspecs = DataAxis.parse_datum_spec(caxis, default_column, ms=ms) if caxis else [None] * 4
        # parse axis specifications
        xfunction, xcolumn, xcorr, xitercorr = xspecs
        yfunction, ycolumn, ycorr, yitercorr = yspecs
        afunction, acolumn, acorr, aitercorr = aspecs
        cfunction, ccolumn, ccorr, citercorr = cspecs
        # does anything here depend on correlation?
        datum_itercorr = (xitercorr or yitercorr or aitercorr or citercorr)
        if datum_itercorr:
            have_corr_dependence = True
        if "FLAG" in (xcolumn, ycolumn, acolumn, ccolumn) or "FLAG_ROW" in (xcolumn, ycolumn, acolumn, ccolumn):
            if not options.noflags:
                log.info(": plotting a flag column implies that flagged data will not be masked")
                options.noflags = True

        # do we iterate over correlations/Stokes to make separate plots now?
        if datum_itercorr and options.iter_corr:
            corr_list = subset.corr.numbers
        else:
            corr_list = [None]

        def describe_corr(corrvalue):
            """Returns list of correlation labels corresponding to this corr setting"""
            if corrvalue is None:
                return subset.corr.names
            elif corrvalue is False:
                return []
            else:
                return [ms.all_corr.names[corrvalue]]

        for corr in corr_list:
            plot_xcorr = corr if xcorr is None else xcorr  # False if no corr in datum, None if all, else set to iterant or to fixed value
            plot_ycorr = corr if ycorr is None else ycorr
            plot_acorr = corr if acorr is None else acorr
            plot_ccorr = corr if ccorr is None else ccorr
            xdatum = DataAxis.register(xfunction, xcolumn, plot_xcorr, ms=ms, minmax=(xmin, xmax), subset=subset)
            ydatum = DataAxis.register(yfunction, ycolumn, plot_ycorr, ms=ms, minmax=(ymin, ymax),  subset=subset)
            adatum = afunction and DataAxis.register(afunction, acolumn, plot_acorr, ms=ms,  subset=subset)
            cdatum = cfunction and DataAxis.register(cfunction, ccolumn, plot_ccorr, ms=ms,
                                                     minmax=(cmin, cmax), ncol=cnum, subset=subset)

            # figure out plot properties -- basically construct a descriptive name and label
            # looks complicated, but we're just trying to figure out what to put in the plot title...
            props = dict()
            titles = []
            labels = []
            # start with column and correlation(s)
            if ycolumn and not ydatum.mapper.column:   # only put column if not fixed by mapper
                titles.append(ycolumn)
                labels.append(col_to_label(ycolumn))
            titles += describe_corr(plot_ycorr)
            labels += describe_corr(plot_ycorr)
            titles += [ydatum.mapper.fullname, "vs"]
            if ydatum.function:
                labels.append(ydatum.function)
            # add x column/subset.corr, if different
            if xcolumn and (xcolumn != ycolumn or not xdatum.function) and not xdatum.mapper.column:
                titles.append(xcolumn)
                labels.append(col_to_label(xcolumn))
            if plot_xcorr != plot_ycorr:
                titles += describe_corr(plot_xcorr)
                labels += describe_corr(plot_xcorr)
            titles += [xdatum.mapper.fullname]
            if xdatum.function:
                labels.append(xdatum.function)
            props['title'] = " ".join(titles)
            props['label'] = "-".join(labels)
            # build up intensity label
            if afunction:
                titles, labels = [ared], [ared]
                if acolumn and (acolumn != xcolumn or acolumn != ycolumn) and adatum.mapper.column is None:
                    titles.append(acolumn)
                    labels.append(col_to_label(acolumn))
                if plot_acorr and (plot_acorr != plot_xcorr or plot_acorr != plot_ycorr):
                    titles += describe_corr(plot_acorr)
                    labels += describe_corr(plot_acorr)
                titles += [adatum.mapper.fullname]
                if adatum.function:
                    labels.append(adatum.function)
                props['alpha_title'] = " ".join(titles)
                props['alpha_label'] = "-".join(labels)
            else:
                props['alpha_title'] = props['alpha_label'] = ''
            # build up color-by label
            if cfunction:
                titles, labels = [], []
                if ccolumn and (ccolumn != xcolumn or ccolumn != ycolumn) and cdatum.mapper.column is None:
                    titles.append(ccolumn)
                    labels.append(col_to_label(ccolumn))
                if plot_ccorr and (plot_ccorr != plot_xcorr or plot_ccorr != plot_ycorr):
                    titles += describe_corr(plot_ccorr)
                    labels += describe_corr(plot_ccorr)
                if cdatum.mapper.fullname:
                    titles.append(cdatum.mapper.fullname)
                if cdatum.function:
                    labels.append(cdatum.function)
                if not cdatum.discretized_delta:
                    if not cdatum.discretized_labels or len(cdatum.discretized_labels) > cdatum.nlevels:
                        titles.append(f"(modulo {cdatum.nlevels})")
                props['color_title'] = " ".join(titles)
                props['color_label'] = "-".join(labels)
            else:
                props['color_title'] = props['color_label'] = ''

            all_plots.append((props, xdatum, ydatum, adatum, ared, cdatum))
            log.debug(f"adding plot for {props['title']}")

    join_corrs = not options.iter_corr and len(subset.corr) > 1 and have_corr_dependence

    log.info('                 : you have asked for {} plots employing {} unique datums'.format(len(all_plots),
                                                                                                len(DataAxis.all_axes)))
    if not len(all_plots):
        sys.exit(0)

    log.debug(f"taql is {mytaql}, group_cols is {group_cols}, join subset.corr is {join_corrs}")

    dataframes, np = \
        data_plots.get_plot_data(ms, group_cols, mytaql, ms.chan_freqs,
                                 chanslice=chanslice, subset=subset,
                                 noflags=options.noflags, noconj=options.noconj,
                                 iter_field=options.iter_field, iter_spw=options.iter_spw,
                                 iter_scan=options.iter_scan,
                                 join_corrs=join_corrs,
                                 row_chunk_size=options.row_chunk_size)

    log.info(f": rendering {len(dataframes)} dataframes with {np:.3g} points into {len(all_plots)} plot types")

    ## each dataframe is an instance of the axes being iterated over -- on top of that, we need to iterate over plot types

    # dictionary of substitutions for filename and title
    keys = {}
    keys['ms'] = os.path.basename(os.path.splitext(options.ms.rstrip("/"))[0])
    keys['timestamp'] = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    # dictionary of titles for these identifiers
    titles = dict(field="", field_num="field", scan="scan", spw="spw", antenna="ant", colortitle="coloured by")

    keys['field_num'] = subset.field.numbers if options.field != 'all' else ''
    keys['field'] = subset.field.names if options.field != 'all' else ''
    keys['scan'] = subset.scan.names if options.scan != 'all' else ''
    keys['ant'] = subset.ant.names if options.ant != 'all' else ''
    keys['spw'] = subset.spw.names if options.spw != 'all' else ''

    keys['suffix'] = suffix = options.suffix.format(**options.__dict__) if options.suffix else ''
    keys['_suffix'] = f".{suffix}" if suffix else ''

    def generate_string_from_keys(template, keys, listsep=" ", titlesep=" ", prefix=" "):
        """Converts list of keys into a string suitable for plot titles or filenames.
        listsep is used to separate list elements: " " for plot titles, "_" for filenames.
        titlesep is used to separate names from values: " " for plot titles, "_" for filenames
        prefix is used to prefix {_key} substitutions: ", " for titles, "-" for filenames
        """
        full_keys = keys.copy()
        # convert lists to strings, add Capitalized keys with titles
        for key, value in keys.items():
            capkey = key.title()
            if value == '':
                full_keys[capkey] = ''
            else:
                if type(value) is list:                               # e.g. scan=[1,2] becomes scan="1 2"
                    full_keys[key] = value = listsep.join(map(str, value))
                if key in titles:                                     # e.g. Scan="scan 1 2"
                    full_keys[capkey] = f"{titles[key]}{titlesep}{value}"
                else:
                    full_keys[capkey] = value
        # add _keys which supply prefixes
        full_keys.update({f"_{key}": (f"{prefix}{value}" if value else '') for key, value in full_keys.items()})
        # finally, format
        return template.format(**full_keys)
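    # For example (hypothetical values): with keys ms="myms" and field="3C286", the title
    # template "{ms}{_field}" renders as "myms, 3C286" (prefix ", "), while the filename
    # template renders as "myms-3C286" (prefix "-").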

    jobs = []
    executor = None

    def render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel):
        """Renders a single plot. Make this a function since we might call it in parallel"""
        log.info(f": rendering {pngname}")
        normalize = options.norm
        if normalize == "auto":
            normalize = "log" if cdatum is not None else "eq_hist"
        if options.profile:
            context = dask.diagnostics.ResourceProfiler
        else:
            context = nullcontext
        with context() as profiler:
            result = data_plots.create_plot(df, xdatum, ydatum, adatum, ared, cdatum,
                                      cmap=cmap, bmap=bmap, dmap=dmap, normalize=normalize,
                                      xlabel=xlabel, ylabel=ylabel, title=title, pngname=pngname,
                                      options=options)
        if result:
            log.info(f'                 : wrote {pngname}')
            if profiler is not None:
                profile_file = os.path.splitext(pngname)[0] + ".prof.html"
                dask.diagnostics.visualize(profiler, file_path=profile_file, show=False, save=True)
                log.info(f'                 : wrote profiler info to {profile_file}')


    for (fld, spw, scan, antenna), df in dataframes.items():
        # update keys to be substituted into title and filename
        if fld is not None:
            keys['field_num'] = fld
            keys['field'] = ms.field[fld]
        if spw is not None:
            keys['spw'] = spw
        if scan is not None:
            keys['scan'] = scan
        if antenna is not None:
            keys['ant'] = ms.all_antenna[antenna]

        # now loop over plot types
        for props, xdatum, ydatum, adatum, ared, cdatum in all_plots:
            keys.update(title=props['title'], label=props['label'],
                        alphatitle=props['alpha_title'], alphalabel=props['alpha_label'],
                        colortitle=props['color_title'], colorlabel=props['color_label'],
                        xname=xdatum.fullname, yname=ydatum.fullname,
                        xunit=xdatum.mapper.unit, yunit=ydatum.mapper.unit)

            pngname = generate_string_from_keys(options.pngname, keys, "_", "_", "-")
            title   = generate_string_from_keys(options.title, keys, " ", " ", ", ")
            xlabel  = generate_string_from_keys(options.xlabel, keys, " ", " ", ", ")
            ylabel  = generate_string_from_keys(options.ylabel, keys, " ", " ", ", ")

            if options.dir:
                pngname = os.path.join(options.dir, pngname)

            # make output directory, if needed
            dirname = os.path.dirname(pngname)
            if dirname and not os.path.exists(dirname):
                os.mkdir(dirname)
                log.info(f'                 : created output directory {dirname}')

            if options.num_parallel < 2 or len(all_plots) < 2:
                render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel)
            else:
                # create the thread pool once, on first use
                if executor is None:
                    from concurrent.futures import ThreadPoolExecutor
                    executor = ThreadPoolExecutor(options.num_parallel)
                log.info(f'                 : submitting job for {pngname}')
                jobs.append(executor.submit(render_single_plot, df, xdatum, ydatum, adatum, ared, cdatum,
                                            pngname, title, xlabel, ylabel))

    # wait for jobs to finish
    if executor is not None:
        log.info(f'                 : waiting for {len(jobs)} jobs to complete')
        for job in jobs:
            job.result()

    clock_stop = time.time()
    elapsed = str(round((clock_stop-clock_start), 2))

    log.info('Total time       : %s seconds' % (elapsed))
    log.info('Finished')
    blank()
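A hypothetical invocation of this entry point (the MS name and axis choices are placeholders, chosen to match the options defined above):

main(["my.ms", "-x", "FREQ", "-y", "DATA:amp", "-c", "ANTENNA1", "--iter-field"])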
Example #7
def main(argv):

    # TODO: I want to inspect a different time profile method
    clock_start = time.time()

    # default # of CPUs

    # ---  dealing with input arguments ---
    parser, optimization_opts = cli()
    # various hidden performance-testing options
    data_mappers.add_options(optimization_opts)
    data_plots.add_options(optimization_opts)
    options = parser.parse_args(argv)
    if options.debug:
        shade_ms.log_console_handler.setLevel(logging.DEBUG)

    # pass options to shade_ms
    data_mappers.set_options(options)
    data_plots.set_options(options)

    # set colormaps
    cmap = data_plots.get_colormap(options.cmap)
    bmap = data_plots.get_colormap(options.bmap)
    dmap = data_plots.get_colormap(options.dmap)

    # figure out list of plots to make
    # define plot parameters
    [xaxes, yaxes, columns, caxes, aaxes, areds,
     xmins, xmaxs, ymins, ymaxs, amins, amaxs, cmins, cmaxs, cnums] = parse_plot_spec(parser,
                                                                                      options)


    # check markup arguments
    extra_markup = []
    for funcname, funcargs in options.markup or []:
        from matplotlib.axes import Axes
        if funcname not in dir(Axes):
            parser.error(f"unknown function given in --markup {funcname}")
        args = yaml.safe_load(funcargs)
        kwargs = {}
        if type(args) not in {list, dict}:
            parser.error(f"invalid arguments to --markup {funcname} {funcargs}")
        if type(args) is list:
            if len(args) > 0 and type(args[-1]) is dict:
                kwargs = args[-1]
                args = args[:-1]
        else:
            kwargs = args
            args = []
        extra_markup.append((funcname, args, kwargs))
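    # For example (hypothetical): --markup axvspan "[1.2, 1.5, {color: red, alpha: 0.3}]"
    # parses into ("axvspan", [1.2, 1.5], {"color": "red", "alpha": 0.3}).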

    # check vline and hline
    for attr in 'hline', 'vline':
        if getattr(options, attr, None):
            for spec in getattr(options, attr).split(","):
                match = re.match('(.*?)((--|-|-\\.|:)([a-zA-Z#].*)?)?$', spec)
                try:
                    coord = float(match.group(1))
                except ValueError:
                    coord = None
                    parser.error(f"invalid --{attr} setting '{spec}'")
                linestyle = match.group(3) or '-'
                extra_markup.append((f"ax{attr}", [coord], dict(ls=linestyle,
                                                                color=match.group(4) or "black")))

    # re-check that kwargs are valid
    for funcname, args, kwargs in extra_markup:
        log.info(f"markup: {funcname} *{args} **{kwargs})")
        if not all([re.match('^\w+$', kw) for kw in kwargs.keys()]):
            log.error("the above is not a valid markup specification, please fix")
            sys.exit(1)

    # check chan slice
    try:
        chanslice, chanslice_spec = parse_slice_spec(options.chan)
    except ValueError:
        # parser.error(f"invalid selection --{'chan'} {options.chan}")
        parser.error(f"invalid selection --chan {options.chan}")

    # issue warning if only a single antenna is specified
    num_ants_warning = False
    if options.ant_num:
        specs = options.ant_num.split(',')
        if len(specs) == 1 and re.fullmatch(r"\d+", specs[0]):
            ant_spec = "ant{}".format(int(specs[0])-1)
            num_ants_warning = True
        # check ant selection before processing
        for spec in options.ant_num.split(","):
            try:
                antslice, antslice_spec = parse_slice_spec(spec)
            except ValueError:
                parser.error(f"invalid selection --{'ant-num'} {options.ant_num}")
    elif 'all' not in options.ant:
        specs = options.ant.split(',')
        if len(specs) == 1:
            ant_spec = "{}".format(specs[0])
            num_ants_warning = True
    if num_ants_warning:
        msg = f"Suggested usage: '--baseline {ant_spec}-*' to select all baselines to {ant_spec}."
        warnings.warn(msg, category=UserWarning)

    log.info(" ".join(sys.argv))

    separator()
    # ---  dealing with input arguments ---

    ms = MSInfo(options.ms, log=log)

    separator()
    log.info(": Data selected for plotting:")

    group_cols = ['FIELD_ID', 'DATA_DESC_ID']

    # --- building SQL query --
    mytaql = []

    class Subset(object):
        pass
    subset = Subset()

    if 'all' not in options.ant or options.ant_num:
        if options.ant_num:
            ant_subset = set()
            for spec in options.ant_num.split(","):
                if re.fullmatch(r"\d+", spec):
                    ant_subset.add(int(spec))
                else:
                    antslice, antslice_spec = parse_slice_spec(spec)
                    ant_subset.update(ms.all_antenna.numbers[antslice])
            subset.ant = ms.antenna.get_subset(sorted(ant_subset))
        else:
            subset.ant = ms.antenna.get_subset(options.ant)
        log.info(f"Antenna name(s)  : {' '.join(subset.ant.names)}")
        antnum_set = f"[{','.join(map(str, subset.ant.numbers))}]"
        mytaql.append(f"ANTENNA1 IN {antnum_set} && ANTENNA2 IN {antnum_set}")
    else:
        subset.ant = ms.antenna
        log.info('Antenna(s)       : all')

    if options.baseline == 'all':
        log.info('Baseline(s)      : all')
        subset.baseline = ms.baseline
    elif options.baseline == 'noautocorr':
        log.info('Baseline(s)      : all except autocorrelations')
        subset.baseline = ms.all_baseline.get_subset([i for i in ms.baseline.numbers if ms.baseline_lengths[i] !=0 ])
        mytaql.append("ANTENNA1!=ANTENNA2")
    elif options.baseline == 'autocorr':
        log.info('Baseline(s)      : autocorrelations')
        subset.baseline = ms.all_baseline.get_subset([i for i in ms.baseline.numbers if ms.baseline_lengths[i] == 0])
        mytaql.append("ANTENNA1==ANTENNA2")
    else:
        bls = set()
        a1a2 = set()
        for blspec in options.baseline.split(","):
            match = re.fullmatch(r"(\w+)-(\w*|[*])", blspec)
            ant1 = match and ms.antenna[match.group(1)]
            ant2 = match and (ms.antenna[match.group(2)] if match.group(2) not in ['', '*'] else '*')
            if ant1 is None or ant2 is None:
                raise ValueError(f"invalid baseline '{blspec}'")
            if ant2 == '*':
                ant2set = ms.all_antenna.numbers
            else:
                ant2set = [ant2]
            # loop over all partner antennas, recording baselines with antenna pairs ordered a1 <= a2
            for ant2 in ant2set:
                a1, a2 = min(ant1, ant2), max(ant1, ant2)
                a1a2.add((a1, a2))
                bls.add(ms.baseline_number(a1, a2))
        # group_cols.append('ANTENNA1')
        subset.baseline = ms.all_baseline.get_subset(sorted(bls))
        log.info(f"Baseline(s)      : {' '.join(subset.baseline.names)}")
        mytaql.append("||".join([f'(ANTENNA1=={ant1}&&ANTENNA2=={ant2})||(ANTENNA1=={ant2}&&ANTENNA2=={ant1})'
                                 for ant1, ant2 in a1a2]))

    if options.field != 'all':
        subset.field = ms.field.get_subset(options.field)
        log.info(f"Field(s)         : {' '.join(subset.field.names)}")
        # here for now, workaround for https://github.com/ska-sa/dask-ms/issues/100, should be inside if clause
        mytaql.append("||".join([f'FIELD_ID=={fld}' for fld in subset.field.numbers]))
    else:
        subset.field = ms.field
        log.info('Field(s)         : all')

    if options.spw != 'all':
        subset.spw = ms.spw.get_subset(options.spw)
        log.info(f"SPW(s)           : {' '.join(subset.spw.names)}")
        mytaql.append("||".join([f'DATA_DESC_ID=={ddid}' for ddid in subset.spw.numbers]))
    else:
        subset.spw = ms.spw
        log.info('SPW(s)           : all')

    if options.scan != 'all':
        subset.scan = ms.scan.get_subset(options.scan, allow_numeric_indices=False)
        log.info(f"Scan(s)          : {' '.join(subset.scan.names)}")
        mytaql.append("||".join([f'SCAN_NUMBER=={n}' for n in subset.scan.numbers]))
    else:
        subset.scan = ms.scan
        log.info('Scan(s)          : all')
    if options.iter_scan:
        group_cols.append('SCAN_NUMBER')

    if chanslice == slice(None):
        log.info('Channels         : all')
    else:
        log.info(f"Channels         : {':'.join(str(x) if x is not None else '' for x in chanslice_spec)}")

    mytaql = ' && '.join([f"({t})" for t in mytaql]) if mytaql else ''
    # --- end of building SQL query ---

    if options.corr == "ALL":
        subset.corr = ms.corr
    else:
        # recognize shortcut when it's just Stokes indices, convert to list
        if re.fullmatch(r"[IQUV]+", options.corr):
            options.corr = ",".join(options.corr)
        subset.corr = ms.all_corr.get_subset(options.corr)
    log.info(f"Corr/Stokes      : {' '.join(subset.corr.names)}")

    separator()

    # check minmax cache
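    # the limits-cache filename template may contain an {ms} placeholder, which is substituted with the MS basename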
    msbase = os.path.splitext(os.path.basename(options.ms))[0]
    cache_file = options.lim_file.format(ms=msbase)
    if options.dir and "/" not in cache_file:
        cache_file = os.path.join(options.dir, cache_file)

    # try to load the minmax cache file
    if not os.path.exists(cache_file):
        minmax_cache = {}
    else:
        log.info(f"loading minmax cache from {cache_file}")
        try:
            minmax_cache = json.load(open(cache_file, "rt"))
            if type(minmax_cache) is not dict:
                raise TypeError("cache content is not a dict")
        except Exception as exc:
            log.error(f"error reading cache file: {exc}. Minmax cache will be reset.")
            minmax_cache = {}

    # figure out list of plots to make
    all_plots = []

    # This will be True if any of the specified axes change with correlation
    have_corr_dependence = False

    # now go create definitions
    for xaxis, yaxis, default_column, caxis, aaxis, ared, xmin, xmax, ymin, ymax, amin, amax, cmin, cmax, cnum in \
        zip(xaxes, yaxes, columns, caxes, aaxes, areds, xmins, xmaxs, ymins, ymaxs, amins, amaxs, cmins, cmaxs, cnums):
        # get axis specs
        xspecs = DataAxis.parse_datum_spec(xaxis, default_column, ms=ms)
        yspecs = DataAxis.parse_datum_spec(yaxis, default_column, ms=ms)
        aspecs = DataAxis.parse_datum_spec(aaxis, default_column, ms=ms) if aaxis else [None] * 4
        cspecs = DataAxis.parse_datum_spec(caxis, default_column, ms=ms) if caxis else [None] * 4
        # parse axis specifications
        xfunction, xcolumn, xcorr, xitercorr = xspecs
        yfunction, ycolumn, ycorr, yitercorr = yspecs
        afunction, acolumn, acorr, aitercorr = aspecs
        cfunction, ccolumn, ccorr, citercorr = cspecs
        # does anything here depend on correlation?
        datum_itercorr = (xitercorr or yitercorr or aitercorr or citercorr)
        if datum_itercorr:
            have_corr_dependence = True
        if "FLAG" in (xcolumn, ycolumn, acolumn, ccolumn) or "FLAG_ROW" in (xcolumn, ycolumn, acolumn, ccolumn):
            if not options.noflags:
                log.info(": plotting a flag column implies that flagged data will not be masked")
                options.noflags = True

        # do we iterate over correlations/Stokes to make separate plots now?
        if datum_itercorr and options.iter_corr:
            corr_list = subset.corr.numbers
        else:
            corr_list = [None]

        def describe_corr(corrvalue):
            """Returns list of correlation labels corresponding to this corr setting"""
            if corrvalue is None:
                return subset.corr.names
            elif corrvalue is False:
                return []
            else:
                return [ms.all_corr.names[corrvalue]]

        for corr in corr_list:
            plot_xcorr = corr if xcorr is None else xcorr  # False if no corr in datum, None if all, else set to iterant or to fixed value
            plot_ycorr = corr if ycorr is None else ycorr
            plot_acorr = corr if acorr is None else acorr
            plot_ccorr = corr if ccorr is None else ccorr
            xdatum = DataAxis.register(xfunction, xcolumn, plot_xcorr, ms=ms, minmax=(xmin, xmax), subset=subset,
                                       minmax_cache=minmax_cache if options.xlim_load else None)
            ydatum = DataAxis.register(yfunction, ycolumn, plot_ycorr, ms=ms, minmax=(ymin, ymax), subset=subset,
                                       minmax_cache=minmax_cache if options.ylim_load else None)
            adatum = afunction and DataAxis.register(afunction, acolumn, plot_acorr, ms=ms,
                                                     minmax=(amin, amax), subset=subset)
            cdatum = cfunction and DataAxis.register(cfunction, ccolumn, plot_ccorr, ms=ms,
                                                     minmax=(cmin, cmax), ncol=cnum, subset=subset,
                                                     minmax_cache=minmax_cache if options.clim_load else None)

            # figure out plot properties -- basically construct a descriptive name and label
            # looks complicated, but we're just trying to figure out what to put in the plot title...
            props = dict()
            titles = []
            labels = []
            # start with column and correlation(s)
            if ycolumn and not ydatum.mapper.column:   # only put column if not fixed by mapper
                titles.append(ycolumn)
                labels.append(col_to_label(ycolumn))
            titles += describe_corr(plot_ycorr)
            labels += describe_corr(plot_ycorr)
            if ydatum.mapper.fullname:
                titles += [ydatum.mapper.fullname]
            titles += ["vs"]
            if ydatum.function:
                labels.append(ydatum.function)
            # add x column/subset.corr, if different
            if xcolumn and (xcolumn != ycolumn or not xdatum.function) and not xdatum.mapper.column:
                titles.append(xcolumn)
                labels.append(col_to_label(xcolumn))
            if plot_xcorr is not plot_ycorr:
                titles += describe_corr(plot_xcorr)
                labels += describe_corr(plot_xcorr)
            if xdatum.mapper.fullname:
                titles += [xdatum.mapper.fullname]
            if xdatum.function:
                labels.append(xdatum.function)
            props['title'] = " ".join(titles)
            props['label'] = "-".join(labels)
            # build up intensity label
            if afunction:
                titles, labels = [ared], [ared]
                if acolumn and (acolumn != xcolumn or acolumn != ycolumn) and adatum.mapper.column is None:
                    titles.append(acolumn)
                    labels.append(col_to_label(acolumn))
                if plot_acorr is not None and plot_acorr is not False and \
                        (plot_acorr is not plot_xcorr or plot_acorr is not plot_ycorr):
                    titles += describe_corr(plot_acorr)
                    labels += describe_corr(plot_acorr)
                titles += [adatum.mapper.fullname]
                if adatum.function:
                    labels.append(adatum.function)
                props['alpha_title'] = " ".join(titles)
                props['alpha_label'] = "-".join(labels)
            else:
                props['alpha_title'] = props['alpha_label'] = ''
            # build up color-by label
            if cfunction:
                titles, labels = [], []
                if ccolumn and (ccolumn != xcolumn or ccolumn != ycolumn) and cdatum.mapper.column is None:
                    titles.append(ccolumn)
                    labels.append(col_to_label(ccolumn))
                if plot_ccorr is not None and plot_ccorr is not False and \
                        (plot_ccorr is not plot_xcorr or plot_ccorr is not plot_ycorr):
                    titles += describe_corr(plot_ccorr)
                    labels += describe_corr(plot_ccorr)
                if cdatum.mapper.fullname:
                    titles.append(cdatum.mapper.fullname)
                if cdatum.function:
                    labels.append(cdatum.function)
                props['color_title'] = " ".join(titles)
                props['color_label'] = "-".join(labels)
                props['color_modulo'] = f"(into {cdatum.nlevels} colours)"
            else:
                props['color_title'] = props['color_label'] = ''

            all_plots.append((props, xdatum, ydatum, adatum, ared, cdatum))
            log.debug(f"adding plot for {props['title']}")

    # reset minmax cache if requested
    if options.lim_save_reset:
        minmax_cache = {}

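    # correlations are stacked ("joined") into a single dataframe only when we're not iterating over them,
    # more than one correlation is selected, and at least one plotted axis actually depends on correlation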
    join_corrs = not options.iter_corr and len(subset.corr) > 1 and have_corr_dependence

    log.info('                 : you have asked for {} plots employing {} unique datums'.format(len(all_plots),
                                                                                                len(DataAxis.all_axes)))
    if not len(all_plots):
        sys.exit(0)

    log.debug(f"taql is {mytaql}, group_cols is {group_cols}, join subset.corr is {join_corrs}")

    dataframes, index_subsets, total_points = \
        data_plots.get_plot_data(ms, group_cols, mytaql, ms.chan_freqs,
                                 chanslice=chanslice, subset=subset,
                                 noflags=options.noflags, noconj=options.noconj,
                                 iter_field=options.iter_field, iter_spw=options.iter_spw,
                                 iter_scan=options.iter_scan, iter_ant=options.iter_ant,
                                 iter_baseline=options.iter_baseline,
                                 join_corrs=join_corrs,
                                 row_chunk_size=options.row_chunk_size)
    if len(dataframes) < 1:
        log.warn("No data for selection subset")
    log.info(f": rendering {len(dataframes)} dataframes with {np:.3g} points into {len(all_plots)} plot types")

    ## each dataframe is an instance of the axes being iterated over -- on top of that, we need to iterate over plot types

    # dictionary of substitutions for filename and title
    keys = {}
    keys['ms'] = os.path.basename(os.path.splitext(options.ms.rstrip("/"))[0])
    keys['timestamp'] = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    # dictionary of titles for these identifiers
    titles = dict(field="", field_num="field", scan="scan", spw="spw", antenna="ant", colortitle="coloured by")

    keys['field_num'] = subset.field.numbers if options.field != 'all' else ''
    keys['field'] = subset.field.names if options.field != 'all' else ''
    keys['scan'] = subset.scan.names if options.scan != 'all' else ''
    keys['ant'] = subset.ant.names if options.ant != 'all' else ''  ## TODO: also handle ant-num settings
    keys['spw'] = subset.spw.names if options.spw != 'all' else ''
    keys['baseline'] = None

    keys['suffix'] = suffix = options.suffix.format(**options.__dict__) if options.suffix else ''
    keys['_suffix'] = f".{suffix}" if suffix else ''

    def generate_string_from_keys(template, keys, listsep=" ", titlesep=" ", prefix=" "):
        """Converts list of keys into a string suitable for plot titles or filenames.
        listsep is used to separate list elements: " " for plot titles, "_" for filenames.
        titlesep is used to separate names from values: " " for plot titles, "_" for filenames
        prefix is used to prefix {_key} substitutions: ", " for titles, "-" for filenames
        """
        full_keys = keys.copy()
        # convert lists to strings, add Capitalized keys with titles
        for key, value in keys.items():
            capkey = key.title()
            if value == '':
                full_keys[capkey] = ''
            else:
                if type(value) is list:                               # e.g. scan=[1,2] becomes scan="1 2"
                    full_keys[key] = value = listsep.join(map(str, value))
                if key in titles:                                     # e.g. Scan="scan 1 2"
                    full_keys[capkey] = f"{titles[key]}{titlesep}{value}"
                else:
                    full_keys[capkey] = value
        # add _keys which supply prefixes
        full_keys.update({f"_{key}": (f"{prefix}{value}" if value else '') for key, value in full_keys.items()})
        # finally, format
        return template.format(**full_keys)

    jobs = []
    executor = None

    def render_single_plot(df, subset, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel):
        """Renders a single plot. Make this a function since we might call it in parallel"""
        log.info(f": rendering {pngname}")
        normalize = options.norm
        if normalize == "auto":
            normalize = "log" if cdatum is not None else ("eq_hist" if adatum is None else 'linear')
        if options.profile:
            context = dask.diagnostics.ResourceProfiler
        else:
            context = nullcontext
        with context() as profiler:
            result = data_plots.create_plot(df, subset, xdatum, ydatum, adatum, ared, cdatum,
                                      cmap=cmap, bmap=bmap, dmap=dmap, normalize=normalize,
                                      min_alpha=options.min_alpha,
                                      saturate_alpha=options.saturate_alpha,
                                      saturate_percentile=options.saturate_perc,
                                      xlabel=xlabel, ylabel=ylabel, title=title, pngname=pngname,
                                      extra_markup=extra_markup,
                                      minmax_cache=minmax_cache,
                                      options=options)
        if result:
            log.info(f'                 : wrote {pngname}')
            if profiler is not None:
                profile_file = os.path.splitext(pngname)[0] + ".prof.html"
                dask.diagnostics.visualize(profiler, file_path=profile_file, show=False, save=True)
                log.info(f'                 : wrote profiler info to {profile_file}')


    for (fld, spw, scan, antenna_or_baseline), df in dataframes.items():
        subset = index_subsets[fld, spw, scan, antenna_or_baseline]
        # update keys to be substituted into title and filename
        if fld is not None:
            keys['field_num'] = fld
            keys['field'] = ms.field[fld]
        if spw is not None:
            keys['spw'] = spw
        if scan is not None:
            keys['scan'] = scan
        if antenna_or_baseline is not None:
            if options.iter_ant:
                keys['ant'] = ms.all_antenna[antenna_or_baseline]
            elif options.iter_baseline:
                keys['baseline'] = ms.all_baseline[antenna_or_baseline]

        # now loop over plot types
        for props, xdatum, ydatum, adatum, ared, cdatum in all_plots:
            keys.update(title=props['title'], label=props['label'],
                        alphatitle=props['alpha_title'], alphalabel=props['alpha_label'],
                        colortitle=props['color_title'],
                        colorlabel=props['color_label'],
                        xname=xdatum.fullname, yname=ydatum.fullname,
                        xunit=xdatum.mapper.unit, yunit=ydatum.mapper.unit)
            if cdatum is not None and cdatum.is_discrete and not cdatum.discretized_labels:
                keys['colortitle'] += ' '+props['color_modulo']

            pngname = generate_string_from_keys(options.pngname, keys, "_", "_", "-")
            title   = generate_string_from_keys(options.title, keys, " ", " ", ", ")
            xlabel  = generate_string_from_keys(options.xlabel, keys, " ", " ", ", ")
            ylabel  = generate_string_from_keys(options.ylabel, keys, " ", " ", ", ")

            if options.dir:
                pngname = os.path.join(options.dir, pngname)

            # make output directory, if needed
            dirname = os.path.dirname(pngname)
            if dirname and not os.path.exists(dirname):
                os.mkdir(dirname)
                log.info(f'                 : created output directory {dirname}')

            if options.num_parallel < 2 or len(all_plots) < 2:
                render_single_plot(df, subset, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel)
            else:
                # create the thread pool once, on first use, and reuse it for all subsequent jobs
                if executor is None:
                    from concurrent.futures import ThreadPoolExecutor
                    executor = ThreadPoolExecutor(options.num_parallel)
                log.info(f'                 : submitting job for {pngname}')
                jobs.append(executor.submit(render_single_plot, df, subset, xdatum, ydatum, adatum, ared, cdatum,
                                            pngname, title, xlabel, ylabel))

    # wait for jobs to finish
    if executor is not None:
        log.info(f'                 : waiting for {len(jobs)} jobs to complete')
        for job in jobs:
            job.result()

    clock_stop = time.time()
    elapsed = round(clock_stop - clock_start, 2)

    log.info(f'Total time       : {elapsed} seconds')

    if minmax_cache and options.lim_save:
        # ensure plain floats, because int64s and the like cause JSON serialization errors
        minmax_cache = {axis: list(map(float, minmax)) for axis, minmax in minmax_cache.items()}

        with open(cache_file, "wt") as file:
            json.dump(minmax_cache, file, sort_keys=True, indent=4, separators=(',', ': '))
        log.info(f"Saved minmax cache to {cache_file} (disable with --no-lim-save)")

    log.info('Finished')
    separator()
Exemplo n.º 8
0
    def __init__(self,
                 column,
                 function,
                 corr,
                 ms,
                 minmax=None,
                 ncol=None,
                 label=None,
                 subset=None):
        """See register() class method above. Not called directly."""
        self.name = ":".join([
            str(x) for x in (function, column, corr, minmax, ncol)
            if x is not None
        ])
        self.ms = ms
        self.function = function  # function to apply to column (see list of DataMappers below)
        self.corr = corr if corr != "all" else None
        self.nlevels = ncol
        self.minmax = vmin, vmax = tuple(minmax) if minmax is not None else (None, None)
        self.label = label
        self._corr_reduce = None
        self.discretized_labels = None  # filled in for corrs, fields and so on

        # set up discretized continuous axis
        if self.nlevels and vmin is not None and vmax is not None:
            self.discretized_delta = delta = (vmax - vmin) / self.nlevels
            self.discretized_bin_centers = numpy.arange(
                vmin + delta / 2, vmax, delta)
        else:
            self.discretized_delta = self.discretized_bin_centers = None

        self.mapper = data_mappers[function]

        # columns with labels?
        if function == 'CORR' or function == 'STOKES':
            corrset = set(subset.corr.names)
            corrset1 = corrset - set("IQUV")
            if corrset1 == corrset:
                name = "Correlation"
            elif not corrset1:
                name = "Stokes"
            else:
                name = "Correlation or Stokes"
            # special case of "corr" if corr is fixed: return constant value fixed here
            # When corr is fixed, we're creating a mapper for a known correlation. When corr is not fixed,
            # we're creating one for a mapper that will iterate over correlations
            if corr is not None:
                self.mapper = DataMapper(name,
                                         "",
                                         column=False,
                                         axis=-1,
                                         mapper=lambda x: corr)
            self.discretized_labels = subset.corr.names
        elif column == "FIELD_ID":
            self.discretized_labels = [
                name for name in ms.field.names if name in subset.field
            ]
        elif column == "ANTENNA1" or column == "ANTENNA2":
            self.discretized_labels = [
                name for name in ms.all_antenna.names if name in subset.ant
            ]
        elif column == "FLAG" or column == "FLAG_ROW":
            self.discretized_labels = ["F", "T"]

        # axis name
        self.fullname = self.mapper.fullname or column or ''

        # if labels were set up, adjust nlevels (but only down, never up)
        if self.discretized_labels is not None and self.nlevels is not None:
            self.nlevels = min(len(self.discretized_labels), self.nlevels)

        if self.function == "_":
            self.function = ""
        self.conjugate = self.mapper.conjugate
        self.timefreq_axis = self.mapper.axis

        # setup columns
        self._ufunc = None
        self.columns = ()

        # does the mapper have no column (i.e. frequency)?
        if self.mapper.column is False:
            log.info(
                f'axis: {function}, range {self.minmax}, discretization {self.nlevels}'
            )
        # does the mapper have a fixed column? This better be consistent
        elif self.mapper.column is not None:
            # if a mapper (such as "uv") implies a fixed column name, make sure it's consistent with what the user said
            if column and self.mapper.column != column:
                raise ValueError(
                    f"'{function}' not applicable with column {column}")
            self.columns = (column, )
            log.info(
                f'axis: {function}({column}), range {self.minmax}, discretization {self.nlevels}'
            )
        # else arbitrary column
        else:
            log.info(
                f'axis: {function}({column}), corr {self.corr}, range {self.minmax}, discretization {self.nlevels}'
            )
            # check for column arithmetic
            match = re.fullmatch(r"(\w+)([*/+-])(\w+)", column)
            if match:
                self.columns = (match.group(1), match.group(3))
                # look up dask ufunc corresponding to arithmetic op
                self._ufunc = {
                    '+': da.add,
                    '*': da.multiply,
                    '-': da.subtract,
                    '/': da.divide
                }[match.group(2)]
            else:
                self.columns = (column, )
Exemplo n.º 9
0
def get_plot_data(msinfo, group_cols, mytaql, chan_freqs,
                  chanslice, subset,
                  noflags, noconj,
                  iter_field, iter_spw, iter_scan, iter_ant, iter_baseline,
                  join_corrs=False,
                  row_chunk_size=100000):

    ms_cols = {'ANTENNA1', 'ANTENNA2'}
    ms_cols.update(msinfo.indexing_columns.keys())
    if not noflags:
        ms_cols.update({'FLAG', 'FLAG_ROW'})
    # get visibility columns
    for axis in DataAxis.all_axes.values():
        ms_cols.update(axis.columns)

    total_num_points = 0  # total number of points to plot

    # output dataframes, indexed by (field, spw, scan, antenna_or_baseline)
    # If any of these axes is not being iterated over, then the index at that position is None
    output_dataframes = OrderedDict()

    # number of rows per each dataframe
    output_rows = OrderedDict()

    # output subsets of indexing columns, indexed by same tuple
    output_subsets = OrderedDict()

    if iter_ant:
        antenna_subsets = zip(subset.ant.numbers, subset.ant.names)
    else:
        antenna_subsets = [(None, None)]
    taql = mytaql

    for antenna, antname in antenna_subsets:
        if antenna is not None:
            taql = f"({mytaql})&&(ANTENNA1=={antenna} || ANTENNA2=={antenna})" if mytaql else \
                    f"(ANTENNA1=={antenna} || ANTENNA2=={antenna})"
        # add baselines to group columns
        if iter_baseline:
            group_cols = list(group_cols) + ["ANTENNA1", "ANTENNA2"]

        # get MS data
        msdata = daskms.xds_from_ms(msinfo.msname, columns=list(ms_cols), group_cols=group_cols, taql_where=taql,
                                    chunks=dict(row=row_chunk_size))
        nrow = sum([len(group.row) for group in msdata])
        if not nrow:
            continue

        if antenna is not None:
            log.info(f': Indexing sub-MS (antenna {antname}) and building dataframes ({nrow} rows, chunk size is {row_chunk_size})')
        else:
            log.info(f': Indexing MS and building dataframes ({nrow} rows, chunk size is {row_chunk_size})')

        # iterate over groups
        for group in msdata:
            if not len(group.row):
                continue
            ddid     =  group.DATA_DESC_ID  # always present
            fld      =  group.FIELD_ID      # always present
            if fld not in subset.field or ddid not in subset.spw:
                log.debug(f"field {fld} ddid {ddid} not in selection, skipping")
                continue
            scan = getattr(group, 'SCAN_NUMBER', None)  # will be present if iterating over scans
            if iter_baseline:
                ant1    = getattr(group, 'ANTENNA1', None)   # will be present if iterating over baselines
                ant2    = getattr(group, 'ANTENNA2', None)   # will be present if iterating over baselines
                baseline = msinfo.baseline_number(ant1, ant2)
            else:
                baseline = None

            # Make frame key -- data subset corresponds to this frame
            dataframe_key = (fld if iter_field else None,
                             ddid if iter_spw else None,
                             scan if iter_scan else None,
                             antenna if antenna is not None else baseline)

            # update subsets of MS indexing columns that we've seen for this dataframe
            output_subset1 = output_subsets.setdefault(dataframe_key,
                                                {column:set() for column in msinfo.indexing_columns.keys()})
            for column, _ in msinfo.indexing_columns.items():
                value = getattr(group, column)
                if np.isscalar(value):
                    output_subset1[column].add(value)
                else:
                    output_subset1[column].update(value.compute().data)

            # number of rows in dataframe
            nrows0 = output_rows.setdefault(dataframe_key, 0)

            # always read flags -- easier that way
            flag = group.FLAG if not noflags else None
            flag_row = group.FLAG_ROW if not noflags else None

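            # baseline numbers are computed with antenna pairs ordered so that a1 <= a2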
            a1 = da.minimum(group.ANTENNA1.data, group.ANTENNA2.data)
            a2 = da.maximum(group.ANTENNA1.data, group.ANTENNA2.data)
            baselines = msinfo.baseline_number(a1, a2)

            freqs = chan_freqs[ddid]
            chans = xarray.DataArray(range(len(freqs)), dims=("chan",))
            wavel = freq_to_wavel(freqs)
            extras = dict(chans=chans, freqs=freqs, wavel=wavel, rows=group.row, baselines=baselines)

            nchan = len(group.chan)
            if flag is not None:
                flag = flag[dict(chan=chanslice)]
                nchan = flag.shape[1]
            shape = (len(group.row), nchan)

            arrays = OrderedDict()
            shapes = OrderedDict()
            ddf = None
            num_points = 0  # counts number of new points generated


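            # loop over correlations: evaluate every registered datum for this correlation, then assemble the
            # results into a dask dataframe (doubling up rows with conjugates if any axis requires conjugation)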
            for corr in subset.corr.numbers:
                # make dictionary of extra values for DataMappers
                extras['corr'] = corr
                # loop over datums to be computed
                for axis in DataAxis.all_axes.values():
                    value = arrays.get(axis.label)
                    # a datum was already computed?
                    if value is not None:
                        # if not joining correlations, then that's the only one we'll need, so continue
                        if not join_corrs:
                            continue
                        # joining correlations, and datum has a correlation dependence: compute another one
                        if axis.corr is None:
                            value = None
                    if value is None:
                        value = axis.get_value(group, corr, extras, flag=flag, flag_row=flag_row, chanslice=chanslice)
                        # print(axis.label, value.compute().min(), value.compute().max())
                        num_points = max(num_points, value.size)
                        if value.ndim == 0:
                            shapes[axis.label] = ()
                        elif value.ndim == 1:
                            timefreq_axis = axis.mapper.axis or 0
                            assert value.shape[0] == shape[timefreq_axis], \
                                   f"{axis.mapper.fullname}: size {value.shape[0]}, expected {shape[timefreq_axis]}"
                            shapes[axis.label] = ("row",) if timefreq_axis == 0 else ("chan",)
                        # else 2D value better match expected shape
                        else:
                            assert value.shape == shape, f"{axis.mapper.fullname}: shape {value.shape}, expected {shape}"
                            shapes[axis.label] = ("row", "chan")
                        arrays[axis.label] = value
                # any new data generated for this correlation? Make dataframe
                if num_points:
                    total_num_points += num_points
                    args = (v for pair in ((array, shapes[key]) for key, array in arrays.items()) for v in pair)
                    df1 = dataframe_factory(("row", "chan"), *args, columns=arrays.keys())
                    # if any axis needs to be conjugated, double up all of them
                    if not noconj and any([axis.conjugate for axis in DataAxis.all_axes.values()]):
                        arr_shape = [(-arrays[axis.label] if axis.conjugate else arrays[axis.label], shapes[axis.label])
                                                for axis in DataAxis.all_axes.values()]
                        args = (v for pair in arr_shape for v in pair)
                        df2 = dataframe_factory(("row", "chan"), *args, columns=arrays.keys())
                        df1 = dask_df.concat([df1, df2], axis=0)
                    ddf = dask_df.concat([ddf, df1], axis=0) if ddf is not None else df1

            # do we already have a frame for this key
            ddf0 = output_dataframes.get(dataframe_key)

            if ddf0 is None:
                log.debug(f"first frame for {dataframe_key}")
                output_dataframes[dataframe_key] = ddf
            else:
                log.debug(f"appending to frame for {dataframe_key}")
                output_dataframes[dataframe_key] = dask_df.concat([ddf0, ddf], axis=0)

    # convert discrete axes into categoricals
    if data_mappers.USE_COUNT_CAT:
        categorical_axes = [axis.label for axis in DataAxis.all_axes.values() if axis.nlevels]
        if categorical_axes:
            log.info(": counting colours")
            for key, ddf in list(output_dataframes.items()):
                output_dataframes[key] = ddf.categorize(categorical_axes)

    # print("===")
    # for ddf in output_dataframes.values():
    #     for axis in DataAxis.all_axes.values():
    #         value = ddf[axis.label].values.compute()
    #         print(axis.label, np.nanmin(value), np.nanmax(value))

    log.info(": complete")
    return output_dataframes, output_subsets, total_num_points
Exemplo n.º 10
0
def create_plot(ddf, index_subsets, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, normalize,
                xlabel, ylabel, title, pngname,
                extra_markup=[],
                min_alpha=40, saturate_percentile=None, saturate_alpha=None,
                minmax_cache=None,
                options=None):

    figx = options.xcanvas / 60
    figy = options.ycanvas / 60
    bgcol = "#" + options.bgcol.lstrip("#")

    xaxis = xdatum.label
    yaxis = ydatum.label
    aaxis = adatum and adatum.label
    caxis = cdatum and cdatum.label

    color_key = color_labels = color_minmax = agg_alpha = None

    # do we need to compute any axis min/max?
    bounds = OrderedDict({xaxis: xdatum.minmax, yaxis: ydatum.minmax})
    unknown = []
    for datum in xdatum, ydatum, cdatum:
        if datum is not None:
            bounds[datum.label] = datum.minmax
            if datum.minmax[0] is None or datum.minmax[1] is None:
                if datum.is_discrete and datum.subset_indices is not None:
                    bounds[datum.label] = 0, len(datum.subset_indices)-1
                else:
                    unknown.append(datum.label)

    if unknown:
        log.info(f": scanning axis min/max for {' '.join(unknown)}")
        compute_bounds(unknown, bounds, ddf)
        # populate cache
        if minmax_cache is not None:
            minmax_cache.update([(label, bounds[label]) for label in unknown])

    # adjust bounds for discrete axes
    canvas_sizes = []
    for datum, size in (xdatum, options.xcanvas), (ydatum, options.ycanvas):
        if datum.is_discrete:
            bounds[datum.label] = bounds[datum.label][0]-0.5, bounds[datum.label][1]+0.5
            size = int(bounds[datum.label][1]) - int(bounds[datum.label][0]) + 1
        canvas_sizes.append(size)

    # create rendering canvas.
    canvas = datashader.Canvas(canvas_sizes[0], canvas_sizes[1], x_range=bounds[xaxis], y_range=bounds[yaxis])

    if aaxis is not None:
        agg_alpha = getattr(datashader.reductions, ared, None) if ared else datashader.reductions.count
        if agg_alpha is None:
            raise ValueError(f"unknown alpha reduction function {ared}")
        agg_alpha = agg_alpha(aaxis)

    if cdatum is not None:
        # aggregation applied to by()
        agg_by = agg_alpha if agg_alpha else datashader.count()

        # figure out mapping from raster planes to colours
        # after this if-else block, category will be an aggregator instance yielding N categories,
        # color_key will be a list of N colors, and color_label will be a list of N textual labels

        if data_mappers.USE_COUNT_CAT:
            cats = getattr(ddf.dtypes, caxis).categories
            log.debug(f'colourizing using {caxis} categorical, {len(cats)} bins')
            category = caxis
            color_key = dmap[:len(cats)]
            color_labels = list(map(str, cats))
        else:
            if cdatum.is_discrete:
                # make dictionary from index to label, omitting values that are not in the MS subset to begin with
                if cdatum.discretized_labels:
                    active_subset = OrderedDict(enumerate(cdatum.discretized_labels))
                # else make up integer labels on the spot
                else:
                    active_subset = OrderedDict(enumerate(map(str, range(bounds[caxis][1]+1))))
                # Check if the subset needs to be refined, because it is known to be smaller for this dataframe
                if len(cdatum.columns) == 1 and cdatum.columns[0] in index_subsets:
                    df_index_subset = index_subsets[cdatum.columns[0]]
                    if cdatum.subset_remapper is not None:
                        remapper = cdatum.subset_remapper.compute()
                        df_index_subset = set(remapper[x] for x in df_index_subset)
                    active_subset = OrderedDict((idx, active_subset[idx]) for idx in df_index_subset)
                    log.debug(f"subset of indices for this axis is a priori {list(active_subset.keys())}")
                # max known index
                max_index = max(active_subset.keys())
                num_colors = min(cdatum.nlevels, len(dmap))
                color_key = dmap[:num_colors]
                # if we have fewer indices than colour levels, and the max index is sensible, we'll aggregate to one
                # raster slice per index value directly
                if len(active_subset) <= num_colors and max_index < max(num_colors, 256):
                    num_colors = max_index+1
                    log.debug(f"aggregating directly into {max_index+1} categories")
                    category = category_modulo(caxis, max_index+1)
                    color_label_list = {idx: [value] for idx, value in active_subset.items()}
                else:
                    log.debug(f"aggregating modulo {num_colors} categories")
                    category = category_modulo(caxis, num_colors)
                    # each slice maps to, potentially, multiple labels from the subset
                    color_label_list = {i: [active_subset[idx] for idx in range(i, max_index+1, num_colors) if idx in active_subset]
                                        for i in range(num_colors)}
                    # and colors just come from the bottom of the colormap
                    color_dict = dict(enumerate(options.dmap[:num_colors]))
                # convert lists of color labels into strings
                color_labels = ['']*num_colors
                for i, labels in color_label_list.items():
                    if len(labels) < 3:
                        color_labels[i] = ",".join(labels)
                    else:
                        color_labels[i] = ",".join(labels[:2] + ["..."])
            # else we discretize a span of values
            else:
                num_colors = min(cdatum.nlevels, len(bmap))
                log.debug(f'colourizing using {caxis} with {num_colors} bins')
                cmin = bounds[caxis][0]
                cdelta = (bounds[caxis][1] - cmin) / num_colors
                category = category_binning(caxis, cmin, cdelta, num_colors)
                # color labels are bin centres
                bin_centers = [cmin + cdelta*(i+0.5) for i in range(num_colors)]
                # map to colors pulled from entire extent of color map
                color_key = [bmap[(i*len(bmap))//num_colors] for i in range(num_colors)]
                color_labels = [str(bin) for bin in bin_centers]
                log.info(f": aggregating using {num_colors} bins at {' '.join(color_labels)})")

        raster = canvas.points(ddf, xaxis, yaxis, agg=datashader.by(category, agg_by))
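        # the by() aggregator gives the raster an extra category dimension: one plane per colour bin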
        is_integer_raster = np.issubdtype(raster.dtype, np.integer)

        # the binning aggregator accumulates flagged points in an extra raster plane
        if isinstance(category, category_binning):
            if is_integer_raster:
                log.info(f": {raster[..., -1].data.sum():.3g} points were flagged ")
            raster = raster[...,:-1]

        if is_integer_raster:
            non_empty = np.array(raster.any(axis=(0, 1)))
        else:
            non_empty = ~(np.isnan(raster.data).all(axis=(0, 1)))
        if not non_empty.any():
            log.info(": no valid data in plot. Check your flags and/or plot limits.")
            return None

        if cdatum.is_discrete and not data_mappers.USE_COUNT_CAT:
            # discard empty planes
            non_empty = np.where(non_empty)[0]
            raster = raster[..., non_empty]
            # compress colours to bottom of colormap, unless asked to preserve assignments
            if options.dmap_preserve:
                color_key = [color_key[bin] for bin in non_empty]
            else:
                color_key = color_key[:len(non_empty)]
            color_labels = [color_labels[bin] for bin in non_empty]

        img = datashader.transfer_functions.shade(raster, color_key=color_key, how=normalize, min_alpha=min_alpha)
        # set color_minmax for colorbar
        color_minmax = bounds[caxis]
    else:
        log.debug(f'rasterizing using {ared}')
        raster = canvas.points(ddf, xaxis, yaxis, agg=agg_alpha)
        if not raster.data.any():
            log.info(": no valid data in plot. Check your flags and/or plot limits.")
            return None
        # get min/max for colorbar
        if aaxis:
            amin, amax = adatum.minmax
            color_minmax = (amin if amin is not None else np.nanmin(raster)), \
                           (amax if amax is not None else np.nanmax(raster))
            color_key = cmap
        log.debug('shading')
        img = datashader.transfer_functions.shade(raster, cmap=cmap, how=normalize, span=color_minmax, min_alpha=min_alpha)

    # resaturate if needed
    if saturate_alpha is not None or saturate_percentile is not None:
        # get alpha channel
        imgval = img.values
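        # img.values is a uint32-packed RGBA array; the alpha channel occupies the top 8 bits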
        alpha = (imgval >> 24)&255
        nulls = alpha<min_alpha
        alpha -= min_alpha
        if nulls.all():
            log.debug(f"alpha<min_alpha for entire plot -- all data below lower clip perhaps?")
        else:
            # if a percentile is specified, use it to derive saturate_alpha
            if saturate_alpha is None:
                saturate_alpha = np.percentile(alpha[~nulls], saturate_percentile)
                log.debug(f"using saturation alpha {saturate_alpha} from {saturate_percentile}th percentile")
            else:
                log.debug(f"using explicit saturation alpha {saturate_alpha}")
            # rescale alpha from [min_alpha, saturation_alpha] to [min_alpha, 255]
            saturation_factor = (255. - min_alpha) / (saturate_alpha - min_alpha)
            alpha = min_alpha + alpha*saturation_factor
            alpha[nulls] = 0
            alpha[alpha>255] = 255
            imgval[:] = (imgval & 0xFFFFFF) | alpha.astype(np.uint32)<<24

    if options.spread_pix:
        img = datashader.transfer_functions.dynspread(img, options.spread_thr, max_px=options.spread_pix)
        log.info(f": spreading ({options.spread_thr} {options.spread_pix})")
    rgb = holoviews.RGB(holoviews.operation.datashader.shade.uint32_to_uint8_xr(img))

    log.debug('done')

    # Set plot limits based on data extent or user values for axis labels

    xmin, xmax = bounds[xaxis]
    ymin, ymax = bounds[yaxis]

    log.debug('rendering image')

    fig = pylab.figure(figsize=(figx, figy))
    ax = fig.add_subplot(111, facecolor=bgcol)

    for funcname, args, kwargs in extra_markup:
        getattr(ax, funcname)(*args, **kwargs)

    ax.imshow(X=rgb.data, extent=[xmin, xmax, ymin, ymax],
              aspect='auto', origin='lower', interpolation='nearest')

    ax.set_title("\n".join(textwrap.wrap(title, 90)), loc='center', fontdict=dict(fontsize=options.fontsize))
    ax.set_xlabel(xlabel, fontdict=dict(fontsize=options.fontsize))
    ax.set_ylabel(ylabel, fontdict=dict(fontsize=options.fontsize))
    # ax.plot(xmin,ymin,'.',alpha=0.0)
    # ax.plot(xmax,ymax,'.',alpha=0.0)

    dx, dy = xmax - xmin, ymax - ymin
    ax.set_xlim([xmin - dx/100, xmax + dx/100])
    ax.set_ylim([ymin - dy/100, ymax + dy/100])

    def decimate_list(x, maxel):
        """Helper function to reduce a list to < given max number of elements, dividing it by decimal factors of 2 and 5"""
        factors = 2, 5, 10
        base = divisor = 1
        while len(x)//divisor > maxel:
            for fac in factors:
                divisor = fac*base
                if len(x)//divisor <= maxel:
                    break
            base *= 10
        return x[::divisor]

    ax.tick_params(labelsize=options.fontsize*0.66)

    # max # of tickmarks and labels to draw for discrete axes
    MAXLABELS = 64   # if we have up to this many labels, show them all
    MAXLABELS1 = 32  # if we have >MAXLABELS to show, then sparsify and get below this number
    MAXTICKS = 300   # if total number of points is within this range, draw them as minor tickmarks

    # do we have discrete labels to put on the axes?
    if xdatum.discretized_labels is not None:
        n = len(xdatum.discretized_labels)
        ticks_labels = list(enumerate(xdatum.discretized_labels))
        if n > MAXLABELS:
            ticks_labels = decimate_list(ticks_labels, MAXLABELS1)         # enforce max number of tick labels
        labels = [label for _, label in ticks_labels]
        rot = 90 if max([len(label) for label in xdatum.discretized_labels])*n > 60 else 0
        ax.set_xticks([x[0] for x in ticks_labels])
        ax.set_xticklabels(labels, rotation=rot)
        if len(ticks_labels) < n and n <= MAXTICKS:
            ax.set_xticks(range(n), minor=True)

    if ydatum.discretized_labels is not None:
        n = len(ydatum.discretized_labels)
        ticks_labels = list(enumerate(ydatum.discretized_labels))
        if n > MAXLABELS:
            ticks_labels = decimate_list(ticks_labels, MAXLABELS1)         # enforce max number of tick labels
        labels = [label for _, label in ticks_labels]
        ax.set_yticks([y[0] for y in ticks_labels])
        ax.set_yticklabels(labels)
        if len(ticks_labels) < n and n <= MAXTICKS:
            ax.set_yticks(range(n), minor=True)

    # colorbar?
    if color_minmax:
        import matplotlib.colors
        import matplotlib.cm
        # discrete axis
        if caxis is not None and cdatum.is_discrete:
            norm = matplotlib.colors.Normalize(-0.5, len(color_key)-0.5)
            ticks = np.arange(len(color_key))
            colormap = matplotlib.colors.ListedColormap(color_key)
        # discretized axis
        else:
            norm = matplotlib.colors.Normalize(*color_minmax)
            colormap = matplotlib.colors.ListedColormap(color_key)
            # auto-mark colorbar, since it represents a continuous range of values
            ticks = None

        cb = fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=colormap), ax=ax, ticks=ticks)

        # adjust ticks for discrete axis
        if caxis is not None and cdatum.is_discrete:
            rot = 0
            # adjust fontsize for number of labels
            fs = max(options.fontsize*min(0.8, 20./len(color_labels)), 6)
            fontdict = dict(fontsize=fs)
            if max([len(lbl) for lbl in color_labels]) > 3 and len(color_labels) < 8:
                rot = 90
                fontdict['verticalalignment'] = 'center'
            cb.ax.set_yticklabels(color_labels, rotation=rot, fontdict=fontdict)
        else:
            cb.ax.tick_params(labelsize=options.fontsize*0.8)

    fig.savefig(pngname, bbox_inches='tight')

    pylab.close()

    return pngname
Exemplo n.º 11
0
    def __init__(self,
                 column,
                 function,
                 corr,
                 ms,
                 minmax=None,
                 ncol=None,
                 label=None,
                 subset=None):
        """See register() class method above. Not called directly."""
        self.name = ":".join([
            str(x) for x in (function, column, corr, minmax, ncol)
            if x is not None
        ])
        self.ms = ms
        self.function = function  # function to apply to column (see list of DataMappers below)
        self.corr = corr if corr != "all" else None
        self.nlevels = ncol
        self.minmax = tuple(minmax) if minmax is not None else (None, None)
        self._minmax_autorange = (self.minmax == (None, None))

        self.label = label
        self._corr_reduce = None
        self._is_discrete = None
        self.mapper = data_mappers[function]

        # if set, axis is discrete and labelled
        self.discretized_labels = None

        # for discrete axes: if a subset of N indices is explicitly selected for plotting, then this
        # is a list of the selected indices, of length N
        self.subset_indices = None
        # ...and this is a dask array that maps selected indices into bins 0...N-1, and all other values into bin N
        self.subset_remapper = None
        # ...and this is the maximum valid index in MS
        maxind = None

        # columns with labels?
        if function == 'CORR' or function == 'STOKES':
            corrset = set(subset.corr.names)
            corrset1 = corrset - set("IQUV")
            if corrset1 == corrset:
                name = "Correlation"
            elif not corrset1:
                name = "Stokes"
            else:
                name = "Correlation or Stokes"
            # special case of "corr" if corr is fixed: return constant value fixed here
            # When corr is fixed, we're creating a mapper for a known correlation. When corr is not fixed,
            # we're creating one for a mapper that will iterate over correlations
            if corr is not None:
                self.mapper = DataMapper(name,
                                         "",
                                         column=False,
                                         axis=-1,
                                         mapper=lambda x: corr)
            self.subset_indices = subset.corr
            maxind = ms.all_corr.numbers[-1]
        elif column == "FIELD_ID":
            self.subset_indices = subset.field
            maxind = ms.field.numbers[-1]
        elif column == "ANTENNA1" or column == "ANTENNA2":
            self.subset_indices = subset.ant
            maxind = ms.antenna.numbers[-1]
        elif column == "SCAN_NUMBER":
            self.subset_indices = subset.scan
            maxind = ms.scan.numbers[-1]
        elif function == "BASELINE":
            self.subset_indices = subset.baseline
            maxind = ms.baseline.numbers[-1]
        elif function == "BASELINE_M":
            bl_subset = set(subset.baseline.numbers)  # active baselines
            numbers = [i for i in ms.baseline_m.numbers if i in bl_subset]
            names = [
                bl for i, bl in zip(ms.baseline_m.numbers, ms.baseline_m.names)
                if i in bl_subset
            ]
            self.subset_indices = ms_info.NamedList("baseline_m", names,
                                                    numbers)
            maxind = ms.baseline.numbers[-1]
        elif column == "FLAG" or column == "FLAG_ROW":
            self.discretized_labels = ["F", "T"]

        # make a remapper
        if self.subset_indices is not None:
            # is the mapping from indices to bins 1:1?
            subind = np.array(self.subset_indices.numbers)
            identity = subind[0] == 0 and ((subind[1:] - subind[:-1])
                                           == 1).all()
            # If mapping is not 1:1, or subset is short of full set, then we need a remapper.
            # Map indices in subset into their ordinal numbers in the subset (0...N-1), and all other indices to N
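            # e.g. with subset indices [2, 5] and maxind 6, the remapper becomes [2, 2, 0, 2, 2, 1, 2]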
            if len(self.subset_indices) < maxind + 1 or not identity:
                remapper = np.full(maxind + 1, len(self.subset_indices))
                for i, index in enumerate(self.subset_indices.numbers):
                    remapper[index] = i
                self.subset_remapper = da.array(remapper)
            self.discretized_labels = self.subset_indices.names
            self.subset_indices = self.subset_indices.numbers

        if self.discretized_labels:
            self._is_discrete = True

        # axis name
        self.fullname = self.mapper.fullname or column or ''

        # if labels were set up, adjust nlevels (but only down, never up)
        if self.discretized_labels is not None and self.nlevels is not None:
            self.nlevels = min(len(self.discretized_labels), self.nlevels)

        if self.function == "_":
            self.function = ""
        self.conjugate = self.mapper.conjugate
        self.timefreq_axis = self.mapper.axis

        # setup columns
        self._ufunc = None
        self.columns = ()

        # does the mapper have no column (i.e. frequency)?
        if self.mapper.column is False:
            log.info(
                f'axis: {function}, range {self.minmax}, discretization {self.nlevels}'
            )
        # does the mapper have a fixed column? This better be consistent
        elif self.mapper.column is not None:
            # if a mapper (such as "uv") implies a fixed column name, make sure it's consistent with what the user said
            if column and self.mapper.column != column:
                raise ValueError(
                    f"'{function}' not applicable with column {column}")
            self.columns = (column, )
            log.info(
                f'axis: {function}({column}), range {self.minmax}, discretization {self.nlevels}'
            )
        # else arbitrary column
        else:
            log.info(
                f'axis: {function}({column}), corr {self.corr}, range {self.minmax}, discretization {self.nlevels}'
            )
            # check for column arithmetic
            match = re.fullmatch(r"(\w+)([*/+-])(\w+)", column)
            if match:
                self.columns = (match.group(1), match.group(3))
                # look up dask ufunc corresponding to arithmetic op
                self._ufunc = {
                    '+': da.add,
                    '*': da.multiply,
                    '-': da.subtract,
                    '/': da.divide
                }[match.group(2)]
            else:
                self.columns = (column, )