Esempio n. 1
0
 def test_label_to_latex(self):
     """Check that ``label_to_latex`` escapes LaTeX-special characters.

     Missing or empty labels become the empty string, plain text passes
     through unchanged, bare underscores are escaped, and characters that
     are already escaped (``\\_``, ``\\%``) are left as they are.
     """
     cases = [
         (None, ''),
         ('', ''),
         ('Test', 'Test'),
         ('Test_with_underscore', r'Test\_with\_underscore'),
         (r'Test_with\_escaped\%characters',
          r'Test\_with\_escaped\%characters'),
     ]
     for raw, expected in cases:
         assert label_to_latex(raw) == expected
Esempio n. 2
0
 def test_label_to_latex(self):
     """Verify LaTeX escaping of plot labels by ``label_to_latex``.

     ``None`` and ``''`` both map to ``''``; text without special
     characters is returned verbatim; ``_`` is escaped to ``\\_``; input
     that is already escaped is not escaped a second time.
     """
     expectations = {
         '': '',
         'Test': 'Test',
         'Test_with_underscore': r'Test\_with\_underscore',
         r'Test_with\_escaped\%characters':
             r'Test\_with\_escaped\%characters',
     }
     assert label_to_latex(None) == ''
     for given in ('', 'Test', 'Test_with_underscore',
                   r'Test_with\_escaped\%characters'):
         assert label_to_latex(given) == expectations[given]
Esempio n. 3
0
    def process_svg(self, outputfile):
        """Render this plot as an SVG with hover-to-reveal source labels.

        A PNG rendition is produced first through the parent class's
        ``finalize`` (with ``close=False`` so the figure stays open), using
        ``outputfile`` with ``.svg`` replaced by ``.png``. Each legend entry
        whose text matches ``re_source_label`` is then split into
        ``(label, source)``. The legend keeps ``label``, and ``source`` is
        drawn as a boxed text just above the top-right corner of the first
        axes. The figure is serialised to SVG, and the matching legend text
        and line get mouse-over / mouse-out handlers (``ShowLabel`` /
        ``HideLabel``, presumably defined in ``HOVERSCRIPT``), while the
        source text is initially hidden.

        Parameters
        ----------
        outputfile : `str`
            path of the SVG output

        Returns
        -------
        whatever ``self.finalize_svg`` returns
        """
        # rasterise data lines (presumably to keep the SVG size manageable)
        for ax in self.plot.axes:
            for line in ax.lines:
                line.set_rasterized(True)

        # render image
        super(SvgMixin, self).finalize(
            outputfile=outputfile.replace('.svg', '.png'), close=False)

        # make new text labels for the channel names
        ax = self.plot.axes[0]
        leg = ax.legend_
        # legend indices that were given hover labels; these are not
        # necessarily contiguous because non-matching entries are skipped
        indices = []
        if leg is not None:
            for i, (text, line) in enumerate(
                    zip(leg.get_texts(), leg.get_lines())):
                try:
                    label, source = re_source_label.match(
                        text.get_text()).groups()
                except (AttributeError, ValueError):
                    # no match (None has no .groups) or wrong group count
                    continue
                channel = label_to_latex(str(source))
                text.set_text(label)
                t2 = ax.text(
                    0.994, 1.02, channel, ha='right', va='bottom',
                    fontsize=text.get_fontsize(), zorder=1000,
                    transform=ax.transAxes,
                    bbox={'facecolor': 'white', 'edgecolor': 'lightgray',
                          'pad': 10.})
                text.set_gid('leg_text_%d' % i)
                line.set_gid('leg_patch_%d' % i)
                t2.set_gid('label_%d' % i)
                indices.append(i)

        # tmp save
        f = StringIO()
        self.plot.save(f, format='svg')

        # parse svg
        tree, xmlid = etree.XMLID(f.getvalue())
        tree.set('onload', 'init(evt)')

        # add effects, only for the ids that were actually assigned above
        for i in indices:
            pl = xmlid['leg_patch_%d' % i]
            ll = xmlid['leg_text_%d' % i]
            tl = xmlid['label_%d' % i]
            pl.set('cursor', 'pointer')
            pl.set('onmouseover', "ShowLabel(this)")
            pl.set('onmouseout', "HideLabel(this)")
            ll.set('cursor', 'pointer')
            ll.set('onmouseover', "ShowLabel(this)")
            ll.set('onmouseout', "HideLabel(this)")
            tl.set('class', 'mpl-label')
            tl.set('visibility', 'hidden')

        return self.finalize_svg(tree, outputfile, script=HOVERSCRIPT)
Esempio n. 4
0
    def _process(self):
        """Load all data, and generate this `SpectrumDataPlot`

        This draws the median ASD of the first channel's spectrogram
        (``format='asd'``) in grey, with its spectral variance drawn via
        ``ax.plot_variance``. It then overlays any configured reference
        spectra and horizontal lines and applies the remaining
        ``self.pargs`` to the axes before finalising.

        ``self.pargs`` keys consumed here:

        - ``figsize``, ``suptitle``, ``cmap``, ``hline``
        - ``reference``, ``reference1``, ``reference2``, ...: source of a
          reference spectrum, read with ``Spectrum.read``
        - ``referenceN-<opt>`` / ``referenceN_<opt>``: option ``<opt>`` for
          that reference. ``filter`` and ``scale`` modify the data; any
          other option is passed to ``ax.plot``.

        Every other remaining key ``k`` is applied as ``ax.set_k(val)``,
        falling back to ``setattr(ax, k, val)``.

        Returns
        -------
        whatever ``self.finalize()`` returns
        """
        plot = self.plot = SpectrumPlot(
            figsize=self.pargs.pop('figsize', [12, 6]))
        ax = plot.gca()
        channel = self.channels[0]

        if self.state:
            self.pargs.setdefault(
                'suptitle',
                '[%s-%s, state: %s]' % (self.span[0], self.span[1],
                                        label_to_latex(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # parse plotting arguments
        cmap = self.pargs.pop('cmap', None)
        varargs = self.parse_variance_kwargs()
        plotargs = self.parse_plot_kwargs()[0]
        # the legend arguments are not used by this plot, but the parser is
        # still called -- presumably so it consumes legend-related keys from
        # self.pargs before they reach the axes (TODO confirm)
        self.parse_legend_kwargs()

        # get reference arguments. Options are attached to their source by
        # name, not by sort order, so e.g. 'reference_color' still belongs
        # to 'reference' although it sorts after 'reference1'.
        refs = []
        refs_by_key = {}
        for key in sorted(self.pargs.keys()):
            if key == 'reference' or re.match(r'reference\d+\Z', key):
                refs_by_key[key] = {'source': self.pargs.pop(key)}
                refs.append(refs_by_key[key])
        for key in sorted(self.pargs.keys()):
            match = re.match(r'(reference\d*)[-_](.+)\Z', key)
            if match and match.group(1) in refs_by_key:
                refs_by_key[match.group(1)][match.group(2)] = (
                    self.pargs.pop(key))

        # get channel arguments
        if hasattr(channel, 'asd_range'):
            low, high = channel.asd_range
            varargs.setdefault('low', low)
            varargs.setdefault('high', high)

        # restrict the data request to the active state segments unless all
        # data were requested
        if self.state and not self.all_data:
            valid = self.state.active
        else:
            valid = SegmentList([self.span])
        livetime = float(abs(valid))

        if livetime:
            plotargs.setdefault('vmin', 1 / livetime)
        plotargs.setdefault('vmax', 1.)
        # the variance map takes no label; tolerate it being absent
        plotargs.pop('label', None)

        specgram = get_spectrogram(channel, valid, query=False,
                                   format='asd').join(gap='ignore')

        if specgram.size:
            asd = specgram.median(axis=0)
            asd.name = None
            variance = specgram.variance(**varargs)
            # normalize the variance by the number of spectrogram bins in
            # the live time (presumably giving the fraction of time spent at
            # each amplitude)
            variance /= livetime / specgram.dt.value
            # plot
            ax.plot(asd, color='grey', linewidth=0.3)
            ax.plot_variance(variance, cmap=cmap, **plotargs)

        # allow channel data to set parameters
        if getattr(channel, 'frequency_range', None) is not None:
            self.pargs.setdefault('xlim', channel.frequency_range)
            if isinstance(self.pargs['xlim'], Quantity):
                self.pargs['xlim'] = self.pargs['xlim'].value
        if hasattr(channel, 'asd_range'):
            self.pargs.setdefault('ylim', channel.asd_range)

        # display references
        for ref in refs:
            source = ref.pop('source')
            try:
                refspec = Spectrum.read(source)
            except IOError as e:
                warnings.warn('IOError: %s' % str(e))
                continue
            except Exception as e:
                # hack for old versions of GWpy
                # TODO: remove me when GWSumm requires GWpy > 0.1
                if 'Format could not be identified' in str(e):
                    refspec = Spectrum.read(source, format='dat')
                else:
                    raise
            # applied whichever of the two read paths succeeded
            if 'filter' in ref:
                refspec = refspec.filter(*ref.pop('filter'))
            if 'scale' in ref:
                refspec *= ref.pop('scale')
            ax.plot(refspec, **ref)

        # customise
        hlines = list(self.pargs.pop('hline', []))
        for key, val in self.pargs.items():
            try:
                getattr(ax, 'set_%s' % key)(val)
            except AttributeError:
                setattr(ax, key, val)

        # add horizontal lines; an optional trailing dict gives their style
        if hlines and isinstance(hlines[-1], dict):
            lineparams = hlines.pop(-1)
        else:
            lineparams = {'color': 'r', 'linestyle': '--'}
        for yval in hlines:
            try:
                yval = float(yval)
            except (TypeError, ValueError):
                continue
            ax.plot(ax.get_xlim(), [yval, yval], **lineparams)

        # set grid
        ax.grid(b=True, axis='both', which='both')

        if not plot.colorbars:
            plot.add_colorbar(ax=ax, visible=False)

        return self.finalize()
Esempio n. 5
0
    def process(self, outputfile=None):
        """Get data and generate the figure.

        Draws one amplitude histogram per channel, distributing the
        channels over the axes from ``self.init_plot()`` in turn
        (``itertools.cycle``). Unless a ``range`` is given in the first set
        of histogram arguments, all histograms share
        ``axes[0].common_limits(data)`` as their range.

        Parameters
        ----------
        outputfile : `str`, optional
            passed through to ``self.finalize``

        Returns
        -------
        whatever ``self.finalize(outputfile=outputfile)`` returns
        """
        # get plot and axes
        (plot, axes) = self.init_plot()

        if self.state:
            self.pargs.setdefault(
                'suptitle',
                '[%s-%s, state: %s]' % (self.span[0], self.span[1],
                                        label_to_latex(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # extract histogram arguments
        histargs = self.parse_plot_kwargs()

        # the validity window is the same for every channel
        if self.state and not self.all_data:
            valid = self.state.active
        else:
            valid = SegmentList([self.span])

        # get data
        data = []
        for channel in self.channels:
            series = get_timeseries(channel, valid, query=False).join(
                gap='ignore', pad=numpy.nan)
            data.append(series)
            # allow channel data to set parameters
            if hasattr(series.channel, 'amplitude_range'):
                self.pargs.setdefault('xlim', series.channel.amplitude_range)

        # get a common range for all histograms unless one was given
        if 'range' not in histargs[0]:
            limits = axes[0].common_limits(data)
            for d in histargs:
                d['range'] = limits

        # plot; an empty series still gets an empty line carrying its label
        # and colour (presumably so it keeps a legend entry)
        for ax, arr, pargs in zip(cycle(axes), data, histargs):
            if arr.size == 0:
                kwargs = dict(
                    (k, pargs[k]) for k in ['label', 'color'] if pargs.get(k))
                ax.plot([], **kwargs)
            else:
                ax.hist(arr, **pargs)

        # customise plot: title only on the first axes, xlabel only on the
        # last, ylabel on the middle axes (odd count) or the first (even)
        legendargs = self.parse_legend_kwargs()
        n_axes = len(axes)
        for i, ax in enumerate(axes):
            for key, val in self.pargs.items():
                if key == 'title' and i > 0:
                    continue
                if key == 'xlabel' and i < (n_axes - 1):
                    continue
                if key == 'ylabel' and (
                        (n_axes % 2 and i != n_axes // 2) or
                        (n_axes % 2 == 0 and i > 0)):
                    continue
                try:
                    getattr(ax, 'set_%s' % key)(val)
                except AttributeError:
                    setattr(ax, key, val)
            if len(self.channels) > 1:
                plot.add_legend(ax=ax, **legendargs)
        # with an even number of axes, move the ylabel to the axes just
        # above the middle and shift it down so it sits centred
        if n_axes % 2 == 0 and axes[0].get_ylabel():
            label = axes[0].yaxis.label
            ax = axes[n_axes // 2 - 1]
            ax.set_ylabel(label.get_text())
            ax.yaxis.label.set_position((0, -.2 / n_axes))
            if n_axes != 2:
                label.set_text('')

        # add extra axes and finalise
        if not plot.colorbars:
            for ax in axes:
                plot.add_colorbar(ax=ax, visible=False)
        return self.finalize(outputfile=outputfile)
Esempio n. 6
0
    def process(self, outputfile=None):
        """Get data and generate the figure.

        Plots a 2-D histogram of the first channel's data (x) against the
        second channel's (y). With a single channel, that channel is
        histogrammed against itself. Empty bins are masked and the result
        is drawn with ``pcolormesh``.

        Parameters
        ----------
        outputfile : `str`, optional
            passed through to ``self.finalize``

        Returns
        -------
        whatever ``self.finalize(outputfile=outputfile)`` returns

        Raises
        ------
        ValueError
            if a ``grid`` parameter is not one of ``'off'``, ``'both'``,
            ``'major'`` or ``'minor'``
        """
        # get histogram parameters
        plot, axes = self.init_plot()
        ax = axes[0]

        if self.state:
            self.pargs.setdefault(
                'suptitle',
                '[%s-%s, state: %s]' % (self.span[0], self.span[1],
                                        label_to_latex(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # the validity window is the same for every channel
        if self.state and not self.all_data:
            valid = self.state.active
        else:
            valid = SegmentList([self.span])

        # get data
        data = []
        for channel in self.channels:
            data.append(get_timeseries(channel, valid, query=False).join(
                gap='ignore', pad=numpy.nan))
        if len(data) == 1:
            data.append(data[0])
        xdata, ydata = data[0], data[1]

        # allow channel data to set parameters
        self.pargs.setdefault('xlabel', label_to_latex(xdata.name))
        self.pargs.setdefault('ylabel', label_to_latex(ydata.name))
        if hasattr(xdata.channel, 'amplitude_range'):
            self.pargs.setdefault('xlim', xdata.channel.amplitude_range)
        if hasattr(ydata.channel, 'amplitude_range'):
            self.pargs.setdefault('ylim', ydata.channel.amplitude_range)

        # histogram; empty bins are masked so they are not drawn
        hist_kwargs = self.parse_hist_kwargs()
        h, xedges, yedges = numpy.histogram2d(xdata, ydata, **hist_kwargs)
        h = numpy.ma.masked_where(h == 0, h)
        x, y = numpy.meshgrid(xedges, yedges, copy=False, sparse=True)

        # plot; histogram2d indexes x along the first axis while
        # pcolormesh expects rows to follow y, hence the transpose
        pcmesh_kwargs = self.parse_pcmesh_kwargs()
        ax.pcolormesh(x, y, h.T, **pcmesh_kwargs)

        # customise plot
        for key, val in self.pargs.items():
            try:
                getattr(ax, 'set_%s' % key)(val)
            except AttributeError:
                if key == 'grid':
                    # pass booleans: matplotlib no longer accepts the
                    # strings 'on'/'off' as the visibility flag
                    if val == 'off':
                        ax.grid(False)
                    elif val in ['both', 'major', 'minor']:
                        ax.grid(True, which=val)
                    else:
                        raise ValueError("Invalid ax.grid argument; "
                                         "valid options are: 'off', "
                                         "'both', 'major' or 'minor'")
                else:
                    setattr(ax, key, val)

        # add extra axes and finalise
        if not plot.colorbars:
            plot.add_colorbar(ax=ax, visible=False)
        return self.finalize(outputfile=outputfile)
Esempio n. 7
0
    def annotate_save_plot(self, arg_list):
        """After the derived class generated a plot
        object finish the process

        This method:

        - sets up the axes, colour bar and (for non-image plots) legend;
        - builds the title from the user titles followed by a line with the
          sample rates and processing parameters;
        - sets the axis labels and super-title, optionally adding a grid;
        - relabels an ``auto-gps`` x-axis with its UTC epoch;
        - saves the figure to ``arg_list.out``, defaulting to
          ``./gwpy.png``.

        Parameters
        ----------
        arg_list
            parsed command-line options (presumably an
            `argparse.Namespace`). Reads ``nocolorbar``, ``title``,
            ``xlabel``, ``ylabel``, ``nogrid``, ``suptitle`` and ``out``.
        """
        import re

        from astropy.time import Time
        from gwpy.plotter.tex import label_to_latex

        self.ax = self.plot.gca()
        # set up axes
        self.setup_xaxis(arg_list)
        self.setup_yaxis(arg_list)
        self.setup_iaxis(arg_list)

        # every plot gets a colour bar (invisible unless this is an image
        # plot with colour bars enabled)
        if self.is_image():
            if arg_list.nocolorbar:
                self.plot.add_colorbar(visible=False)
            else:
                self.plot.add_colorbar(label=self.get_color_label())
        else:
            self.plot.add_colorbar(visible=False)

        # image plots don't have legends
        if not self.is_image():
            leg = self.ax.legend(prop={'size': 10})
            # if only one series is plotted hide legend
            if self.n_datasets == 1 and leg:
                try:
                    leg.remove()
                except NotImplementedError:
                    leg.set_visible(False)

        # add titles
        title = ''
        if arg_list.title:
            for t in arg_list.title:
                if len(title) > 0:
                    title += "\n"
                title += t
        # info on the processing
        start = self.start_list[0]
        startGPS = Time(start, format='gps')
        timeStr = "%s - %10d (%ds)" % (startGPS.iso, start, self.dur)

        # list the different sample rates available in all time series
        fs_set = set(ts.sample_rate for ts in self.timeseries)
        fs_str = ', '.join('(%s)' % fs for fs in fs_set)

        if self.is_freq_plot:
            spec = r'%s, Fs=%s, secpfft=%.1f (bw=%.3f), overlap=%.2f' % (
                timeStr, fs_str, self.secpfft, 1/self.secpfft, self.overlap)
        else:
            xdur = self.xmax - self.xmin
            spec = r'Fs=%s, duration: %.1f' % (fs_str, xdur)
        spec += ", " + self.filter
        if len(title) > 0:
            title += "\n"
        title += spec

        title = label_to_latex(title)
        self.plot.set_title(title, fontsize=12)
        self.log(3, ('Title is: %s' % title))

        if arg_list.xlabel:
            xlabel = label_to_latex(arg_list.xlabel)
        else:
            xlabel = self.get_xlabel()
        if xlabel:
            self.plot.set_xlabel(xlabel)
            self.log(3, ('X-axis label is: %s' % xlabel))

        if arg_list.ylabel:
            ylabel = label_to_latex(arg_list.ylabel)
        else:
            ylabel = self.get_ylabel(arg_list)

        if ylabel:
            self.plot.set_ylabel(ylabel)
            self.log(3, ('Y-axis label is: %s' % ylabel))

        if not arg_list.nogrid:
            self.ax.grid(b=True, which='major', color='k', linestyle='solid')
            self.ax.grid(b=True, which='minor', color='0.06',
                         linestyle='dotted')

        # info on the channel
        if arg_list.suptitle:
            sup_title = arg_list.suptitle
        else:
            sup_title = self.get_sup_title()
        sup_title = label_to_latex(sup_title)
        self.plot.suptitle(sup_title, fontsize=18)

        self.log(3, ('Super title is: %s' % sup_title))
        self.show_plot_info()

        # change the label for GPS time so Josh is happy
        if self.ax.get_xscale() == 'auto-gps':
            xscale = self.ax.xaxis._scale
            epoch = xscale.get_epoch()
            unit = xscale.get_unit_name()
            # drop trailing zero fractional seconds from the ISO string
            utc = re.sub(r'\.0+', '',
                         Time(epoch, format='gps', scale='utc').iso)
            self.plot.set_xlabel('Time (%s) from %s (%s)' % (unit, utc, epoch))
            self.ax.xaxis._set_scale(unit, epoch=epoch)

        # if they specified an output file write it
        # save the figure. Note type depends on extension of
        # output filename (png, jpg, pdf)
        if arg_list.out:
            out_file = arg_list.out
        else:
            out_file = "./gwpy.png"

        self.log(3, ('xinch: %.2f, yinch: %.2f, dpi: %d' %
                     (self.xinch, self.yinch, self.dpi)))

        # NOTE(review): `figsize` is not a documented savefig() keyword;
        # newer matplotlib versions may reject it -- confirm against the
        # matplotlib version in use
        self.plot.savefig(out_file, edgecolor='white',
                          figsize=[self.xinch, self.yinch], dpi=self.dpi)
        # report the file actually written (arg_list.out may be unset)
        self.log(3, ('wrote %s' % out_file))
Esempio n. 8
0
    def draw(self, outputfile=None):
        """Read in all necessary data, and generate the figure.

        For each channel group from ``self.get_channel_groups()`` this
        fetches time-series data (``query=False``) and plots it on the
        first axes. Groups of more than one channel use
        ``ax.plot_timeseries_mmm``; single channels use
        ``ax.plot_timeseries``. Values from the ``hline`` parameter are
        drawn as red dashed horizontal lines, and ``self.pargs`` is applied
        through ``self.apply_parameters``. State segments are added before
        finalising.

        Parameters
        ----------
        outputfile : `str`, optional
            passed through to ``self.finalize``

        Returns
        -------
        whatever ``self.finalize(outputfile=outputfile)`` returns
        """
        (plot, axes) = self.init_plot()
        ax = axes[0]

        plotargs = self.parse_plot_kwargs()
        legendargs = self.parse_legend_kwargs()

        # add data
        channels, groups = zip(*self.get_channel_groups())
        for clist, pargs in zip(groups, plotargs):
            # pad data request to over-fill plots (no gaps at the end)
            # -- when the first channel's sample rate is known the span is
            # presumably widened by one sample period (Segment.protract)
            if self.state and not self.all_data:
                valid = self.state.active
            elif clist[0].sample_rate:
                valid = SegmentList(
                    [self.span.protract(1 / clist[0].sample_rate.value)])
            else:
                valid = SegmentList([self.span])
            # get data
            data = [get_timeseries(c, valid, query=False) for c in clist]
            if len(clist) > 1:
                data = [tsl.join(gap='pad', pad=numpy.nan) for tsl in data]
            # NOTE(review): after the join above each element of `data` is
            # presumably a single series, so for multi-channel groups this
            # iterates over samples rather than series -- confirm intent
            flatdata = [ts for tsl in data for ts in tsl]
            # validate parameters
            for ts in flatdata:
                # double-check empty: series without a start time get the
                # plot start as their epoch
                if (hasattr(ts, 'metadata')
                        and 'x0' not in ts.metadata) or not ts.x0:
                    ts.epoch = self.start
                # double-check log scales: replace exact zeros with a tiny
                # positive value
                if self.pargs.get('logy', False):
                    ts.value[ts.value == 0] = 1e-100
            # set label; note this pop mutates the dict held in `plotargs`
            try:
                label = pargs.pop('label')
            except KeyError:
                try:
                    label = label_to_latex(flatdata[0].name)
                except IndexError:
                    label = clist[0]
                else:
                    # for SVG output, append the channel name in brackets
                    # unless the label already starts with it (up to the
                    # first '.')
                    if self.fileformat == 'svg' and not label.startswith(
                            label_to_latex(str(
                                flatdata[0].channel)).split('.')[0]):
                        label += ' [%s]' % (label_to_latex(
                            str(flatdata[0].channel)))
            # plot groups or single TimeSeries
            if len(clist) > 1:
                ax.plot_timeseries_mmm(*data, label=label, **pargs)
            elif len(flatdata) == 0:
                # no data: plot an empty placeholder series carrying the
                # label
                ax.plot_timeseries(data[0].EntryClass([],
                                                      epoch=self.start,
                                                      unit='s',
                                                      name=label),
                                   label=label,
                                   **pargs)
            else:
                # only the first segment is labelled; later segments reuse
                # its colour
                for ts in data[0]:
                    line = ax.plot_timeseries(ts, label=label, **pargs)[0]
                    label = None
                    pargs['color'] = line.get_color()

            # allow channel data to set parameters
            if len(flatdata):
                chan = get_channel(flatdata[0].channel)
            else:
                chan = get_channel(clist[0])
            if getattr(chan, 'amplitude_range', None) is not None:
                self.pargs.setdefault('ylim', chan.amplitude_range)

        # add horizontal lines to add
        # NOTE(review): 'hline' is read with get(), not popped, so it is
        # also passed to apply_parameters below -- confirm that is harmless
        for yval in self.pargs.get('hline', []):
            try:
                yval = float(yval)
            except ValueError:
                continue
            else:
                ax.plot([self.start, self.end], [yval, yval],
                        linestyle='--',
                        color='red')

        # customise plot
        self.apply_parameters(ax, **self.pargs)

        # NOTE(review): 'label' was popped from each plotargs dict in the
        # loop above, so plotargs[0].get('label') is presumably always None
        # here and the legend is always added -- confirm intent
        if (len(channels) > 1 or plotargs[0].get('label', None)
                in [re.sub(r'(_|\\_)', r'\_', channels[0]), None]):
            plot.add_legend(ax=ax, **legendargs)

        # add extra axes and finalise (an invisible colorbar, presumably to
        # keep the axes geometry consistent with plots that have one)
        if not plot.colorbars:
            plot.add_colorbar(ax=ax, visible=False)

        self.add_state_segments(ax)
        return self.finalize(outputfile=outputfile)
Esempio n. 9
0
    def gen_plot(self, args):
        """Generate the plot from time series and arguments

        This creates ``self.plot`` from ``self.timeseries[0]`` and adds
        every further series, labelled with its LaTeX-escaped channel name.
        Series with more than ``self.max_size`` samples are drawn as
        markers only. The overall data extent is stored in
        ``self.xmin``/``self.xmax`` and ``self.ymin``/``self.ymax``.

        If ``args.xmin``/``args.xmax`` narrow the x range, ``self.ymin`` and
        ``self.ymax`` are recomputed from the samples inside that (inclusive)
        range. If no series has samples there, the global limits are kept.
        Range values below 1e8 are treated as offsets from the first
        series' start (``x0``).
        """
        self.max_size = 16384. * 6400.  # that works on my mac
        self.yscale_factor = 1.0

        from gwpy.plotter.tex import label_to_latex
        from numpy import min as npmin
        from numpy import max as npmax

        first = self.timeseries[0]
        if first.size <= self.max_size:
            self.plot = first.plot()
        else:
            self.plot = first.plot(linestyle='None', marker='.')
        self.ymin = first.min().value
        self.ymax = first.max().value
        self.xmin = first.times.value.min()
        self.xmax = first.times.value.max()

        for ts in self.timeseries[1:]:
            lbl = label_to_latex(ts.channel.name)
            if ts.size <= self.max_size:
                self.plot.add_timeseries(ts, label=lbl)
            else:
                self.plot.add_timeseries(ts, label=lbl,
                                         linestyle='None', marker='.')
            self.ymin = min(self.ymin, ts.min().value)
            self.ymax = max(self.ymax, ts.max().value)
            self.xmin = min(self.xmin, ts.times.value.min())
            self.xmax = max(self.xmax, ts.times.value.max())

        # if they chose to set the range of the x-axis find the range of y
        strt = self.xmin
        stop = self.xmax
        if args.xmin:
            strt = float(args.xmin)
        if args.xmax:
            stop = float(args.xmax)
        if strt != self.xmin or stop != self.xmax:
            new_ymin = None
            new_ymax = None
            for ts in self.timeseries:
                x0 = ts.x0.value
                dt = ts.dt.value
                # small values are offsets from the start of the first
                # series; once shifted (presumably by a GPS x0 > 1e8) they
                # are left alone for the remaining series
                if strt < 1e8:
                    strt += x0
                if stop < 1e8:
                    stop += x0
                # sample-index slice [b, e) covering strt <= t <= stop,
                # clamped to the series length
                b = int(max(0, (strt - x0) / dt))
                e = int(min(ts.size, max(0, (stop - x0) / dt + 1)))
                if b >= e:
                    # this series has no samples inside the range
                    continue
                seg_min = npmin(ts.value[b:e])
                seg_max = npmax(ts.value[b:e])
                if new_ymin is None:
                    new_ymin, new_ymax = seg_min, seg_max
                else:
                    new_ymin = min(new_ymin, seg_min)
                    new_ymax = max(new_ymax, seg_max)
            if new_ymin is not None:
                self.ymin = new_ymin
                self.ymax = new_ymax
        # NOTE(review): yscale_factor is set to 1.0 at the top of this
        # method, so this branch cannot run as written
        if self.yscale_factor > 1:
            self.log(2, ('Scaling y-limits, original: %f, %f)' %
                         (self.ymin, self.ymax)))
            yrange = self.ymax - self.ymin
            mid = (self.ymax + self.ymin) / 2.
            self.ymax = mid + yrange / (2 * self.yscale_factor)
            self.ymin = mid - yrange / (2 * self.yscale_factor)
            self.log(2, ('Scaling y-limits, new: %f, %f)' %
                         (self.ymin, self.ymax)))
Esempio n. 10
0
    def draw(self, outputfile=None):
        """Histogram the durations of the active segments of each flag.

        For every flag in ``self.flags``, the active segments within the
        validity window are fetched and coalesced. A histogram of their
        durations is drawn, distributing the flags over the axes from
        ``self.init_plot()`` in turn. A ``normed`` value of ``'N'``,
        ``'num'`` or ``'number'`` weights each segment by ``1/N``, where N
        is the number of segments. All axes share a common y range unless
        ``ylim`` is given.

        Parameters
        ----------
        outputfile : `str`, optional
            passed through to ``self.finalize``

        Returns
        -------
        whatever ``self.finalize(...)`` returns
        """
        # make axes
        (plot, axes) = self.init_plot()

        # use state to generate suptitle with GPS span
        if self.state:
            self.pargs.setdefault(
                'suptitle',
                '[%s-%s, state: %s]' % (self.span[0], self.span[1],
                                        label_to_latex(str(self.state))))
        else:
            self.pargs.setdefault(
                'suptitle', '[%s-%s]' % (self.span[0], self.span[1]))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # extract plotting arguments
        histargs = self.parse_plot_kwargs()

        # the validity window is the same for every flag
        if self.state and not self.all_data:
            valid = self.state.active
        else:
            valid = SegmentList([self.span])

        # get segments: one duration per active segment of each flag
        # (a real list, so len() also works on Python 3)
        data = []
        for flag in self.flags:
            segs = get_segments(flag, validity=valid, query=False,
                                padding=self.padding).coalesce()
            data.append([float(abs(seg)) for seg in segs.active])

        # get range
        if 'range' not in histargs[0]:
            limits = axes[0].common_limits(data)
            for d in histargs:
                d['range'] = limits

        # plot
        for ax, arr, pargs in zip(cycle(axes), data, histargs):
            if len(arr) == 0:
                kwargs = dict(
                    (k, pargs[k]) for k in ['label', 'color'] if pargs.get(k))
                ax.plot([], **kwargs)
            else:
                if pargs.get('normed', False) in ['N', 'num', 'number']:
                    pargs['normed'] = False
                    # float division: with Python 2 integer division
                    # 1/len(arr) would be 0 for more than one segment
                    pargs.setdefault('weights', [1. / len(arr)] * len(arr))
                ax.hist(arr, **pargs)

        # customise plot: title only on the first axes, xlabel only on the
        # last, ylabel on the middle axes (odd count) or the first (even)
        legendargs = self.parse_legend_kwargs()
        n_axes = len(axes)
        for i, ax in enumerate(axes):
            for key, val in self.pargs.items():
                if key == 'title' and i > 0:
                    continue
                if key == 'xlabel' and i < (n_axes - 1):
                    continue
                if key == 'ylabel' and (
                        (n_axes % 2 and i != n_axes // 2) or
                        (n_axes % 2 == 0 and i > 0)):
                    continue
                try:
                    getattr(ax, 'set_%s' % key)(val)
                except AttributeError:
                    setattr(ax, key, val)
            if len(self.flags) > 1:
                plot.add_legend(ax=ax, **legendargs)
        if n_axes % 2 == 0 and axes[0].get_ylabel():
            label = axes[0].yaxis.label
            ax = axes[n_axes // 2 - 1]
            ax.set_ylabel(label.get_text())
            ax.yaxis.label.set_position((0, -.2 / n_axes))
            if n_axes != 2:
                label.set_text('')

        # set common ylim
        if 'ylim' not in self.pargs:
            y0 = min([ax.get_ylim()[0] for ax in axes])
            y1 = max([ax.get_ylim()[1] for ax in axes])
            for ax in axes:
                ax.set_ylim(y0, y1)

        # add bit mask axes and finalise
        return self.finalize(outputfile=outputfile, transparent=True,
                             pad_inches=0)
Esempio n. 11
0
    def process_svg(self, outputfile):
        """Render this plot as an SVG with hover-to-reveal source labels.

        This first produces a PNG rendition through the parent class's
        ``finalize``, using ``outputfile`` with ``.svg`` replaced by
        ``.png`` and keeping the figure open. Each legend entry whose text
        matches ``re_source_label`` is split into ``(label, source)``: the
        legend keeps ``label``, and ``source`` is drawn in a box just above
        the top-right corner of the first axes. After serialising to SVG,
        the matching legend text and line get ``ShowLabel`` / ``HideLabel``
        mouse handlers (presumably defined in ``HOVERSCRIPT``), and the
        source text starts hidden.

        Parameters
        ----------
        outputfile : `str`
            path of the SVG output

        Returns
        -------
        whatever ``self.finalize_svg`` returns
        """
        # rasterise data lines (presumably to keep the SVG size manageable)
        for ax in self.plot.axes:
            for line in ax.lines:
                line.set_rasterized(True)

        # render image
        super(SvgMixin,
              self).finalize(outputfile=outputfile.replace('.svg', '.png'),
                             close=False)

        # make new text labels for the channel names
        ax = self.plot.axes[0]
        leg = ax.legend_
        # legend indices that were given hover labels; not necessarily
        # contiguous because non-matching entries are skipped
        indices = []
        if leg is not None:
            for i, (text,
                    line) in enumerate(zip(leg.get_texts(), leg.get_lines())):
                try:
                    label, source = re_source_label.match(
                        text.get_text()).groups()
                except (AttributeError, ValueError):
                    # no match (None has no .groups) or wrong group count
                    continue
                channel = label_to_latex(str(source))
                text.set_text(label)
                t2 = ax.text(0.994,
                             1.02,
                             channel,
                             ha='right',
                             va='bottom',
                             fontsize=text.get_fontsize(),
                             zorder=1000,
                             transform=ax.transAxes,
                             bbox={
                                 'facecolor': 'white',
                                 'edgecolor': 'lightgray',
                                 'pad': 10.
                             })
                text.set_gid('leg_text_%d' % i)
                line.set_gid('leg_patch_%d' % i)
                t2.set_gid('label_%d' % i)
                indices.append(i)

        # tmp save
        f = StringIO()
        self.plot.save(f, format='svg')

        # parse svg
        tree, xmlid = etree.XMLID(f.getvalue())
        tree.set('onload', 'init(evt)')

        # add effects, only for the ids that were actually assigned above
        for i in indices:
            pl = xmlid['leg_patch_%d' % i]
            ll = xmlid['leg_text_%d' % i]
            tl = xmlid['label_%d' % i]
            pl.set('cursor', 'pointer')
            pl.set('onmouseover', "ShowLabel(this)")
            pl.set('onmouseout', "HideLabel(this)")
            ll.set('cursor', 'pointer')
            ll.set('onmouseover', "ShowLabel(this)")
            ll.set('onmouseout', "HideLabel(this)")
            tl.set('class', 'mpl-label')
            tl.set('visibility', 'hidden')

        return self.finalize_svg(tree, outputfile, script=HOVERSCRIPT)
Esempio n. 12
0
    def init_plots(self, plotdir=os.curdir):
        """Configure the default list of plots for this tab

        For each state, when a trigger channel is configured, this adds a
        veto-trigger glitchgram, before/after histograms of each ranking
        statistic and of frequency, and before/after glitchgrams coloured
        by each statistic. A segment plot is always added for each state.

        Parameters
        ----------
        plotdir : `str`, optional
            directory in which plots are written, default: `os.curdir`
        """
        label = 'Intersection' if self.intersection else 'Union'

        # escape the tab and ETG names the same way so every title renders
        # consistently; the previous r'\\_' replacement inserted a LaTeX
        # line break before each underscore, and the 'After' title used
        # the raw ETG name
        tabname = label_to_latex(self.name)
        etgstr = label_to_latex(self.etg)
        impact_title = 'Impact of %s (%s)' % (tabname, etgstr)
        after_title = 'After %s (%s)' % (tabname, etgstr)

        self.set_layout([1])
        # only resolve the trigger channel when one is configured, instead
        # of looking up a channel literally named 'None'
        before = get_channel(str(self.channel)) if self.channel else None
        for state in self.states:
            if self.channel:
                after = get_channel(veto_tag(before, self.metaflag,
                                             mode='after'))
                vetoed = get_channel(veto_tag(before, self.metaflag,
                                              mode='vetoed'))

                # -- configure trigger plots
                params = etg.get_etg_parameters(self.etg, IFO=before.ifo)
                glitchgramargs = {
                    'etg': self.etg,
                    'x': params['time'],
                    'y': params['frequency'],
                    'logy': params.get('frequency-log', True),
                    'ylabel': params.get('frequency-label',
                                         get_column_label(params['frequency'])),
                    'edgecolor': 'none',
                    'legend-scatterpoints': 1,
                    'legend-borderaxespad': 0,
                    'legend-bbox_to_anchor': (1.01, 1),
                    'legend-loc': 'upper left',
                }

                # plot before/after glitchgram
                self.plots.append(get_plot('triggers')(
                    [after, vetoed], self.start, self.end, state=state,
                    title=impact_title, outdir=plotdir,
                    labels=['_', 'Vetoed'], colors=['lightblue', 'red'],
                    **glitchgramargs))

                # plot histograms of each statistic and of frequency; each
                # trigger is weighted by 1/span so the y-axis is a rate
                statistics = ['snr']
                if params['det'] != params['snr']:
                    statistics.append('det')
                self.layout.append(len(statistics) + 1)
                for column in statistics + ['frequency']:
                    self.plots.append(get_plot('trigger-histogram')(
                        [before, after], self.start, self.end, state=state,
                        column=params[column], etg=self.etg, outdir=plotdir,
                        title=impact_title,
                        labels=['Before', 'After'],
                        xlabel=params.get('%s-label' % column,
                                          get_column_label(params[column])),
                        color=['red', (0.2, 0.8, 0.2)],
                        logx=params.get('%s-log' % column, True),
                        logy=True,
                        histtype='stepfilled', alpha=0.6,
                        weights=1/float(abs(self.span)), bins=100,
                        ybound=1/float(abs(self.span)) * 0.5, **{
                            'legend-borderaxespad': 0,
                            'legend-bbox_to_anchor': (1.01, 1),
                            'legend-loc': 'upper left'}
                    ))

                # plot triggers before and after, coloured by each statistic
                for stat in statistics:
                    column = params[stat]
                    glitchgramargs.update({
                        'color': column,
                        'clim': params.get('%s-limits' % stat, [3, 100]),
                        'logcolor': params.get('%s-log' % stat, True),
                        'colorlabel': params.get('%s-label' % stat,
                                                 get_column_label(column)),
                    })
                    self.plots.append(get_plot('triggers')(
                        [before], self.start, self.end, state=state,
                        outdir=plotdir, **glitchgramargs))
                    self.plots.append(get_plot('triggers')(
                        [after], self.start, self.end, state=state,
                        title=after_title, outdir=plotdir, **glitchgramargs))
                    self.layout.append(2)

            # -- configure segment plot
            segargs = {
                'state': state,
                'known': {'alpha': 0.1, 'facecolor': 'lightgray'},
                'color': 'red',
            }
            if len(self.flags) == 1:
                sp = get_plot('segments')(self.flags, self.start, self.end,
                                          outdir=plotdir, labels=self.labels,
                                          **segargs)
            else:
                # show the combined metaflag above the individual flags
                sp = get_plot('segments')(
                    [self.metaflag] + self.flags, self.start, self.end,
                    labels=[label] + self.labels, outdir=plotdir, **segargs)
            self.plots.append(sp)
            self.layout.append(1)
Esempio n. 13
0
    def _draw(self):
        """Load all data, and generate this `SpectrumDataPlot`

        Draws the median ASD of the first channel's spectrogram in grey,
        overlaid with its normalised spectral variance, followed by any
        configured reference spectra and horizontal lines.
        """
        plot = self.plot = FrequencySeriesPlot(
            figsize=self.pargs.pop('figsize', [12, 6]))
        ax = plot.gca()

        if self.state:
            self.pargs.setdefault(
                'suptitle', '[%s-%s, state: %s]' %
                (self.span[0], self.span[1], label_to_latex(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # parse plotting arguments
        cmap = self.pargs.pop('cmap', None)
        varargs = self.parse_variance_kwargs()
        plotargs = self.parse_plot_kwargs()[0]
        # NOTE(review): the legend kwargs are not used by this plot; the
        # call is kept in case parsing strips legend-* keys from self.pargs
        self.parse_legend_kwargs()

        # get reference arguments: 'reference' / 'referenceN' name a source
        # file and 'referenceN-<param>' / 'referenceN_<param>' set a plot
        # parameter for that reference; all such keys are removed from pargs
        # (grouping by prefix does not depend on the sort order of '-'/'_')
        refparams = {}
        for key in sorted(self.pargs.keys()):
            match = re.match(r'(reference\d*)(?:[-_](.+))?\Z', key)
            if match is not None:
                refkey, param = match.groups()
                refparams.setdefault(refkey, {})[param or 'source'] = (
                    self.pargs.pop(key))
        refs = []
        for refkey in sorted(refparams):
            if 'source' in refparams[refkey]:
                refs.append(refparams[refkey])
            else:
                warnings.warn('No source given for %r, ignoring its '
                              'parameters' % refkey)

        # get channel arguments
        channel = self.channels[0]
        if hasattr(channel, 'asd_range'):
            low, high = channel.asd_range
            varargs.setdefault('low', low)
            varargs.setdefault('high', high)

        # calculate spectral variance and plot
        if self.state and not self.all_data:
            valid = self.state.active
        else:
            valid = SegmentList([self.span])
        livetime = float(abs(valid))

        if livetime:
            plotargs.setdefault('vmin', 1 / livetime)
        plotargs.setdefault('vmax', 1.)
        # the label is not forwarded to plot_variance
        plotargs.pop('label', None)

        specgram = get_spectrogram(channel,
                                   valid,
                                   query=False,
                                   format='asd').join(gap='ignore')

        if specgram.size:
            asd = specgram.median(axis=0)
            asd.name = None
            variance = specgram.variance(**varargs)
            # normalize the variance by livetime / dt, i.e. the number of
            # spectrogram time bins for contiguous data
            variance /= livetime / specgram.dt.value
            # undo demodulation
            variance = undo_demodulation(variance, channel,
                                         self.pargs.get('xlim', None))

            # plot
            ax.plot(asd, color='grey', linewidth=0.3)
            ax.plot_variance(variance, cmap=cmap, **plotargs)

        # allow channel data to set parameters
        if getattr(channel, 'frequency_range', None) is not None:
            self.pargs.setdefault('xlim', channel.frequency_range)
            if isinstance(self.pargs['xlim'], Quantity):
                self.pargs['xlim'] = self.pargs['xlim'].value
        if hasattr(channel, 'asd_range'):
            self.pargs.setdefault('ylim', channel.asd_range)

        # display references
        for ref in refs:
            source = ref.pop('source')
            try:
                refspec = io.read_frequencyseries(source)
            except IOError as e:  # skip if file can't be read
                warnings.warn('IOError: %s' % str(e))
                continue
            if 'filter' in ref:
                refspec = refspec.filter(*ref.pop('filter'))
            if 'scale' in ref:
                refspec *= ref.pop('scale', 1)
            ax.plot(refspec, **ref)

        # customise
        hlines = list(self.pargs.pop('hline', []))
        self.apply_parameters(ax, **self.pargs)

        # add horizontal lines; a trailing dict gives the line parameters,
        # anything else in the list is treated as a y-value
        if hlines and isinstance(hlines[-1], dict):
            lineparams = hlines.pop(-1)
        else:
            lineparams = {'color': 'r', 'linestyle': '--'}
        for yval in hlines:
            try:
                yval = float(yval)
            except (TypeError, ValueError):
                continue
            ax.plot(ax.get_xlim(), [yval, yval], **lineparams)

        # set grid
        ax.grid(b=True, axis='both', which='both')

        if not plot.colorbars:
            plot.add_colorbar(ax=ax, visible=False)

        return self.finalize()
Esempio n. 14
0
 def test_label_to_latex(self):
     """Plain labels pass through unchanged; underscores get escaped."""
     expectations = (
         ("Test", "Test"),
         ("Test_with_underscore", r"Test\_with\_underscore"),
     )
     for raw, escaped in expectations:
         self.assertEqual(label_to_latex(raw), escaped)
Esempio n. 15
0
    def draw(self, outputfile=None):
        """Get data and generate the figure.

        Histograms the first channel (x-axis) against the second channel
        (y-axis) in 2-D; when only one channel is given it is histogrammed
        against itself. Empty bins are masked before drawing.

        Parameters
        ----------
        outputfile : `str`, optional
            passed through to `self.finalize`

        Raises
        ------
        ValueError
            if the ``grid`` parameter is not one of ``'off'``, ``'both'``,
            ``'major'`` or ``'minor'``
        """
        plot, axes = self.init_plot()
        ax = axes[0]

        if self.state:
            self.pargs.setdefault(
                'suptitle', '[%s-%s, state: %s]' %
                (self.span[0], self.span[1], label_to_latex(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # get data; the valid segments are the same for every channel
        if self.state and not self.all_data:
            valid = self.state.active
        else:
            valid = SegmentList([self.span])
        data = []
        for channel in self.channels:
            data.append(
                get_timeseries(channel, valid,
                               query=False).join(gap='ignore', pad=numpy.nan))
        if len(data) == 1:
            data.append(data[0])

        # allow channel data to set parameters
        self.pargs.setdefault('xlabel', label_to_latex(data[0].name))
        self.pargs.setdefault('ylabel', label_to_latex(data[1].name))
        if hasattr(data[0].channel, 'amplitude_range'):
            self.pargs.setdefault('xlim', data[0].channel.amplitude_range)
        if hasattr(data[1].channel, 'amplitude_range'):
            self.pargs.setdefault('ylim', data[1].channel.amplitude_range)

        # histogram, masking empty bins
        hist_kwargs = self.parse_hist_kwargs()
        h, xedges, yedges = numpy.histogram2d(data[0], data[1], **hist_kwargs)
        h = numpy.ma.masked_where(h == 0, h)
        x, y = numpy.meshgrid(xedges, yedges, copy=False, sparse=True)

        # plot (histogram2d indexes as [x, y], pcolormesh expects [y, x])
        pcmesh_kwargs = self.parse_pcmesh_kwargs()
        ax.pcolormesh(x, y, h.T, **pcmesh_kwargs)

        # customise plot: use ax.set_<key> when it exists, otherwise handle
        # 'grid' specially and set anything else as a plain attribute
        for key, val in self.pargs.items():
            setter = getattr(ax, 'set_%s' % key, None)
            if setter is not None:
                setter(val)
            elif key == 'grid':
                # pass booleans: newer matplotlib treats the string 'off'
                # as a truthy value and would switch the grid on
                if val == 'off':
                    ax.grid(False)
                elif val in ('both', 'major', 'minor'):
                    ax.grid(True, which=val)
                else:
                    raise ValueError("Invalid ax.grid argument; "
                                     "valid options are: 'off', "
                                     "'both', 'major' or 'minor'")
            else:
                setattr(ax, key, val)

        # add extra axes and finalise
        if not plot.colorbars:
            plot.add_colorbar(ax=ax, visible=False)
        return self.finalize(outputfile=outputfile)
Esempio n. 16
0
    def draw(self, outputfile=None):
        """Get data and generate the figure.

        Draws one histogram per channel, cycling through the available
        axes, using a common bin range unless ``range`` is given
        explicitly.

        Parameters
        ----------
        outputfile : `str`, optional
            passed through to `self.finalize`
        """
        # get plot and axes
        (plot, axes) = self.init_plot()

        if self.state:
            self.pargs.setdefault(
                'suptitle', '[%s-%s, state: %s]' %
                (self.span[0], self.span[1], label_to_latex(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # extract histogram arguments (a sequence of dicts, zipped with the
        # per-channel data below)
        histargs = self.parse_plot_kwargs()

        # get data; the valid segments are the same for every channel
        if self.state and not self.all_data:
            valid = self.state.active
        else:
            valid = SegmentList([self.span])
        data = []
        for channel in self.channels:
            data.append(
                get_timeseries(channel, valid,
                               query=False).join(gap='ignore', pad=numpy.nan))
            # allow channel data to set parameters
            if hasattr(data[-1].channel, 'amplitude_range'):
                self.pargs.setdefault('xlim', data[-1].channel.amplitude_range)

        # use a common bin range for all histograms unless one is given
        if 'range' not in histargs[0]:
            limits = axes[0].common_limits(data)
            for d in histargs:
                d['range'] = limits

        # plot
        for ax, arr, pargs in zip(cycle(axes), data, histargs):
            if isinstance(pargs.get('weights', None), (float, int)):
                # expand a scalar weight to one weight per sample
                pargs['weights'] = numpy.ones_like(arr) * pargs['weights']
            try:
                ax.hist(arr, **pargs)
            except ValueError:  # empty dataset
                p2 = pargs.copy()
                # mpl errors on weights; they may not have been given at all
                p2.pop('weights', None)
                if p2.get('log', False) or self.pargs.get('logx', False):
                    p2['bottom'] = 1e-100  # default log 'bottom' is 1e-2
                ax.hist([], **p2)

        # customise plot
        legendargs = self.parse_legend_kwargs()
        nax = len(axes)
        for i, ax in enumerate(axes):
            for key, val in self.pargs.items():
                # title only on the first axes, xlabel only on the last
                if key == 'title' and i > 0:
                    continue
                if key == 'xlabel' and i < (nax - 1):
                    continue
                # ylabel on the middle axes for an odd count, otherwise on
                # the first axes (it is relocated below)
                if key == 'ylabel' and ((nax % 2 and i != nax // 2)
                                        or (nax % 2 == 0 and i > 0)):
                    continue
                setter = getattr(ax, 'set_%s' % key, None)
                if setter is None:
                    setattr(ax, key, val)
                else:
                    setter(val)
            if len(self.channels) > 1:
                plot.add_legend(ax=ax, **legendargs)

        # for an even number of axes, copy the ylabel onto the axes just
        # above the middle and shift it down (presumably to sit between the
        # two middle axes)
        if nax % 2 == 0 and axes[0].get_ylabel():
            label = axes[0].yaxis.label
            ax = axes[nax // 2 - 1]
            ax.set_ylabel(label.get_text())
            ax.yaxis.label.set_position((0, -.2 / nax))
            if nax != 2:
                label.set_text('')

        # add extra axes and finalise
        if not plot.colorbars:
            for ax in axes:
                plot.add_colorbar(ax=ax, visible=False)
        return self.finalize(outputfile=outputfile)
Esempio n. 17
0
    def _draw(self):
        """Load all data, and generate this `SpectrumDataPlot`

        Draws one spectrum per channel (or one coherence spectrum per pair
        of channels for the ``coherence-spectrum`` type), optionally with
        percentile bands, followed by any configured reference spectra and
        horizontal lines.
        """
        plot = self.plot = FrequencySeriesPlot(
            figsize=self.pargs.pop('figsize', [12, 6]))
        ax = plot.gca()
        ax.grid(b=True, axis='both', which='both')

        if self.state:
            self.pargs.setdefault(
                'suptitle', '[%s-%s, state: %s]' %
                (self.span[0], self.span[1], label_to_latex(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # get spectrum format: 'amplitude' or 'power'
        sdform = self.pargs.pop('format')
        method = 'rayleigh' if sdform == 'rayleigh' else None
        use_percentiles = str(
            self.pargs.pop('no-percentiles')).lower() == 'false'

        # parse plotting arguments
        plotargs = self.parse_plot_kwargs()
        legendargs = self.parse_legend_kwargs()
        use_legend = False

        # get reference arguments: 'reference' / 'referenceN' name a source
        # file and 'referenceN-<param>' / 'referenceN_<param>' set a plot
        # parameter for that reference; all such keys are removed from pargs
        # (grouping by prefix does not depend on the sort order of '-'/'_')
        refparams = {}
        for key in sorted(self.pargs.keys()):
            match = re.match(r'(reference\d*)(?:[-_](.+))?\Z', key)
            if match is not None:
                refkey, param = match.groups()
                refparams.setdefault(refkey, {})[param or 'source'] = (
                    self.pargs.pop(key))
        refs = []
        for refkey in sorted(refparams):
            if 'source' in refparams[refkey]:
                refs.append(refparams[refkey])
            else:
                warnings.warn('No source given for %r, ignoring its '
                              'parameters' % refkey)

        # add data; coherence spectra consume the channels in pairs
        if self.type == 'coherence-spectrum':
            iterator = zip(self.channels[0::2], self.channels[1::2], plotargs)
        else:
            iterator = zip(self.channels, plotargs)

        if self.state and not self.all_data:
            valid = self.state
        else:
            valid = SegmentList([self.span])

        for chantuple in iterator:
            channel = chantuple[0]
            pargs = chantuple[-1]

            if self.type == 'coherence-spectrum':
                data = get_coherence_spectrum(
                    [str(channel), str(chantuple[1])], valid, query=False)
            else:
                data = get_spectrum(str(channel),
                                    valid,
                                    query=False,
                                    format=sdform,
                                    method=method)

            # undo demodulation, keeping the returned spectra (rebinding a
            # loop variable alone would throw the result away)
            xlim = self.pargs.get('xlim', None)
            data = [undo_demodulation(spec, channel, xlim) for spec in data]

            # anticipate log problems: drop the first (presumably 0 Hz) bin
            # on a log x-axis and replace zeros on a log y-axis
            if self.pargs.get('logx', False):
                data = [s[1:] for s in data]
            if self.pargs.get('logy', False):
                for sp in data:
                    sp.value[sp.value == 0] = 1e-100

            if 'label' in pargs:
                use_legend = True

            if use_percentiles:
                try:
                    ax.plot_frequencyseries_mmm(*data, **pargs)
                except AttributeError:  # old GWpy
                    ax.plot_spectrum_mmm(*data, **pargs)
            else:
                pargs.pop('alpha', None)
                try:
                    ax.plot_frequencyseries(data[0], **pargs)
                except AttributeError:  # old GWpy
                    ax.plot_spectrum(data[0], **pargs)

            # allow channel data to set parameters
            if getattr(channel, 'frequency_range', None) is not None:
                self.pargs.setdefault('xlim', channel.frequency_range)
                if isinstance(self.pargs['xlim'], Quantity):
                    self.pargs['xlim'] = self.pargs['xlim'].value
            if (sdform in ['amplitude', 'asd']
                    and hasattr(channel, 'asd_range')):
                self.pargs.setdefault('ylim', channel.asd_range)
            elif hasattr(channel, 'psd_range'):
                self.pargs.setdefault('ylim', channel.psd_range)

        # display references
        for ref in refs:
            source = ref.pop('source')
            try:
                refspec = io.read_frequencyseries(source)
            except IOError as e:  # skip if file can't be read
                warnings.warn('IOError: %s' % str(e))
                continue
            ref.setdefault('zorder', -len(refs) + 1)
            if 'filter' in ref:
                refspec = refspec.filter(*ref.pop('filter'))
            if 'scale' in ref:
                refspec *= ref.pop('scale', 1)
            if 'label' in ref:
                use_legend = True
            ax.plot(refspec, **ref)

        # customise: use ax.set_<key> when it exists, otherwise set the
        # value as a plain attribute
        hlines = list(self.pargs.pop('hline', []))
        for key, val in self.pargs.items():
            setter = getattr(ax, 'set_%s' % key, None)
            if setter is None:
                setattr(ax, key, val)
            else:
                setter(val)

        # add horizontal lines; a trailing dict gives the line parameters,
        # anything else in the list is treated as a y-value
        if hlines and isinstance(hlines[-1], dict):
            lineparams = hlines.pop(-1)
        else:
            lineparams = {'color': 'r', 'linestyle': '--'}
        for yval in hlines:
            try:
                yval = float(yval)
            except (TypeError, ValueError):
                continue
            ax.plot(ax.get_xlim(), [yval, yval], **lineparams)

        if use_legend or len(self.channels) > 1 or ax.legend_ is not None:
            plot.add_legend(ax=ax, **legendargs)
        if not plot.colorbars:
            plot.add_colorbar(ax=ax, visible=False)

        return self.finalize()
Esempio n. 18
0
    def _draw(self):
        """Load all data, and generate this `SpectrumDataPlot`

        The first channel is the target spectrum and the remaining channels
        are noise components; the plotted curve is the quadrature sum of
        the noise components divided by the target.

        Raises
        ------
        ValueError
            if fewer than two channels are configured
        RuntimeError
            if the noise components have different frequency resolutions
        """
        # need a target and at least one noise component
        if len(self.channels) < 2:
            raise ValueError("Cannot construct sum of noises: need a target "
                             "channel and at least one noise channel, got "
                             "%d channel(s)" % len(self.channels))

        plot = self.plot = FrequencySeriesPlot(
            figsize=self.pargs.pop('figsize', [12, 6]))
        ax = plot.gca()
        ax.grid(b=True, axis='both', which='both')

        if self.state:
            self.pargs.setdefault(
                'suptitle', '[%s-%s, state: %s]' %
                (self.span[0], self.span[1], label_to_latex(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # get spectrum format: 'amplitude' or 'power'
        sdform = self.pargs.pop('format')

        # parse plotting arguments
        plotargs = self.parse_plot_kwargs()[0]

        # load data: first channel is the target, the rest are noises
        if self.state and not self.all_data:
            valid = self.state
        else:
            valid = SegmentList([self.span])
        spectra = [get_spectrum(str(channel),
                                valid,
                                query=False,
                                format=sdform,
                                method=None)[0]
                   for channel in self.channels]
        target, sumdata = spectra[0], spectra[1:]

        # assert all noise terms have the same resolution
        if any(x.dx != target.dx for x in sumdata):
            raise RuntimeError("Noise components have different resolutions, "
                               "cannot construct sum of noises")
        # match every noise to the target length: zero-pad shorter ones
        # (ndarray.resize fills with zeros) and truncate longer ones so the
        # ratio below has matching shapes
        n = target.size
        for i, d in enumerate(sumdata):
            if d.size < n:
                sumdata[i] = numpy.require(d, requirements=['O'])
                sumdata[i].resize((n, ))
            elif d.size > n:
                sumdata[i] = d[:n]

        # calculate sum of noises in quadrature
        sum_ = sumdata[0]**2
        for d in sumdata[1:]:
            sum_ += d**2
        sum_ **= (1 / 2.)

        # plot ratio of h(t) to sum of noises
        relative = sum_ / target
        ax.plot_frequencyseries(relative, **plotargs)

        # finalize plot
        self.apply_parameters(ax, **self.pargs)
        plot.add_colorbar(ax=ax, visible=False)

        return self.finalize()
Esempio n. 19
0
    def init_plots(self, plotdir=os.curdir):
        """Configure the default list of plots for this tab

        For each state, when a trigger channel is configured, this adds a
        veto-trigger glitchgram, before/after histograms of each ranking
        statistic and of frequency, and before/after glitchgrams coloured
        by each statistic. A segment plot is always added for each state.

        Parameters
        ----------
        plotdir : `str`, optional
            directory in which plots are written, default: `os.curdir`
        """
        label = 'Intersection' if self.intersection else 'Union'

        # escape the ETG name the same way as the tab name so both parts of
        # each title render consistently
        tabname = label_to_latex(self.name)
        etgstr = label_to_latex(self.etg)
        impact_title = 'Impact of %s (%s)' % (tabname, etgstr)
        after_title = 'After %s (%s)' % (tabname, etgstr)

        self.layout = [1]
        # only resolve the trigger channel when one is configured, instead
        # of looking up a channel literally named 'None'
        before = get_channel(str(self.channel)) if self.channel else None
        for state in self.states:
            if self.channel:
                after = get_channel(
                    veto_tag(before, self.metaflag, mode='after'))
                vetoed = get_channel(
                    veto_tag(before, self.metaflag, mode='vetoed'))

                # -- configure trigger plots
                params = etg.get_etg_parameters(self.etg)
                glitchgramargs = {
                    'etg': self.etg,
                    'x': 'time',
                    'y': params['frequency'],
                    'logy': params.get('frequency-log', True),
                    # honour a configured label, as for the other axes
                    'ylabel': params.get(
                        'frequency-label',
                        get_column_label(params['frequency'])),
                    'edgecolor': 'none',
                    'legend-scatterpoints': 1,
                    'legend-borderaxespad': 0,
                    'legend-bbox_to_anchor': (1.01, 1),
                    'legend-loc': 'upper left',
                }

                # plot before/after glitchgram
                self.plots.append(get_plot('triggers')(
                    [after, vetoed], self.start, self.end, state=state,
                    title=impact_title, outdir=plotdir,
                    labels=['_', 'Vetoed'], colors=['lightblue', 'red'],
                    **glitchgramargs))

                # plot histograms of each statistic and of frequency; each
                # trigger is weighted by 1/span so the y-axis is a rate
                statistics = ['snr']
                if params['det'] != params['snr']:
                    statistics.append('det')
                self.layout.append(len(statistics) + 1)
                for column in statistics + ['frequency']:
                    self.plots.append(get_plot('trigger-histogram')(
                        [before, after], self.start, self.end, state=state,
                        column=params[column], etg=self.etg, outdir=plotdir,
                        title=impact_title,
                        labels=['Before', 'After'],
                        xlabel=params.get('%s-label' % column,
                                          get_column_label(params[column])),
                        color=['red', (0.2, 0.8, 0.2)],
                        logx=params.get('%s-log' % column, True),
                        logy=True,
                        histtype='stepfilled',
                        alpha=0.6,
                        weights=1 / float(abs(self.span)),
                        bins=100,
                        ybound=1 / float(abs(self.span)) * 0.5,
                        **{
                            'legend-borderaxespad': 0,
                            'legend-bbox_to_anchor': (1.01, 1),
                            'legend-loc': 'upper left'
                        }))

                # plot triggers before and after, coloured by each statistic
                for stat in statistics:
                    column = params[stat]
                    glitchgramargs.update({
                        'color': column,
                        'clim': params.get('%s-limits' % stat, [3, 100]),
                        'logcolor': params.get('%s-log' % stat, True),
                        'colorlabel': params.get('%s-label' % stat,
                                                 get_column_label(column)),
                    })
                    self.plots.append(get_plot('triggers')(
                        [before], self.start, self.end, state=state,
                        outdir=plotdir, **glitchgramargs))
                    self.plots.append(get_plot('triggers')(
                        [after], self.start, self.end, state=state,
                        title=after_title, outdir=plotdir, **glitchgramargs))
                    self.layout.append(2)

            # -- configure segment plot
            segargs = {
                'state': state,
                'known': {
                    'alpha': 0.1,
                    'facecolor': 'lightgray'
                },
                'color': 'red',
            }
            if len(self.flags) == 1:
                sp = get_plot('segments')(self.flags,
                                          self.start,
                                          self.end,
                                          outdir=plotdir,
                                          labels=self.labels,
                                          **segargs)
            else:
                # show the combined metaflag above the individual flags
                sp = get_plot('segments')(
                    [self.metaflag] + self.flags,
                    self.start,
                    self.end,
                    labels=[label] + self.labels,
                    outdir=plotdir,
                    **segargs)
            self.plots.append(sp)
            self.layout.append(1)
Esempio n. 20
0
    def _draw(self):
        """Load all data, and generate this `SpectrumDataPlot`

        Every channel's spectrum is drawn; the channels after the first are
        treated as noise components and their quadrature sum is drawn on
        top of them.

        Raises
        ------
        ValueError
            if no noise components (channels after the first) are plotted
        RuntimeError
            if the noise components have different frequency resolutions
        """
        plot = self.plot = FrequencySeriesPlot(
            figsize=self.pargs.pop('figsize', [12, 6]))
        ax = plot.gca()
        ax.grid(b=True, axis='both', which='both')

        if self.state:
            self.pargs.setdefault(
                'suptitle', '[%s-%s, state: %s]' %
                (self.span[0], self.span[1], label_to_latex(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # get spectrum format: 'amplitude' or 'power'
        sdform = self.pargs.pop('format')

        # parse plotting arguments
        plotargs = self.parse_plot_kwargs()
        legendargs = self.parse_legend_kwargs()

        # add data; the valid segments are the same for every channel
        if self.state and not self.all_data:
            valid = self.state
        else:
            valid = SegmentList([self.span])
        sumdata = []
        for i, (channel, pargs) in enumerate(zip(self.channels, plotargs)):
            data = get_spectrum(str(channel),
                                valid,
                                query=False,
                                format=sdform,
                                method=None)[0]
            if i:  # every channel after the first is a noise component
                sumdata.append(data)

            # anticipate log problems: drop the first (presumably 0 Hz) bin
            # on a log x-axis and replace zeros on a log y-axis
            if self.pargs.get('logx', False):
                data = data[1:]
            if self.pargs.get('logy', False):
                data.value[data.value == 0] = 1e-100

            # later channels are drawn underneath earlier ones
            pargs.setdefault('zorder', -i)
            ax.plot_frequencyseries(data, **pargs)

        if not sumdata:
            raise ValueError("Cannot construct sum of noises: no noise "
                             "components (channels after the first) given")

        # assert all noise terms have the same resolution
        if any(x.dx != sumdata[0].dx for x in sumdata):
            raise RuntimeError("Noise components have different resolutions, "
                               "cannot construct sum of noises")
        # zero-pad shorter noises to the longest one (ndarray.resize fills
        # with zeros)
        n = max(x.size for x in sumdata)
        for i, d in enumerate(sumdata):
            if d.size < n:
                sumdata[i] = numpy.require(d, requirements=['O'])
                sumdata[i].resize((n, ))

        # plot sum of noises in quadrature
        sumargs = self.parse_sum_params()
        sum_ = sumdata[0]**2
        for d in sumdata[1:]:
            sum_ += d**2
        sum_ **= (1 / 2.)
        ax.plot_frequencyseries(sum_, zorder=1, **sumargs)
        # move the sum to be the second line on the axes (presumably to
        # control its legend position); NOTE(review): modifying Axes.lines
        # directly is deprecated/removed in newer matplotlib -- confirm the
        # supported version
        ax.lines.insert(1, ax.lines.pop(-1))

        self.apply_parameters(ax, **self.pargs)
        plot.add_legend(ax=ax, **legendargs)
        plot.add_colorbar(ax=ax, visible=False)

        return self.finalize()
Esempio n. 21
0
    def gen_plot(self, args):
        """Generate the plot from time series and arguments.

        The first entry of ``self.timeseries`` creates ``self.plot``; every
        further series is added to it, labelled with its LaTeX-escaped
        channel name. Series longer than ``self.max_size`` samples are drawn
        as points (``marker='.'``) instead of lines.

        Side effects: sets ``self.plot``, ``self.max_size``,
        ``self.yscale_factor`` and the data limits ``self.xmin``,
        ``self.xmax``, ``self.ymin`` and ``self.ymax``.

        Parameters
        ----------
        args : argparse.Namespace
            parsed arguments; only ``args.xmin`` and ``args.xmax`` are read.
            When either is given, the y-limits are recomputed from the data
            inside that x-window. Values below 1e8 are treated as offsets
            from the start (``x0``) of the first time series, larger values
            as absolute times.
        """
        self.max_size = 16384. * 6400.  # that works on my mac
        self.yscale_factor = 1.0

        from gwpy.plotter.tex import label_to_latex
        from numpy import min as npmin
        from numpy import max as npmax

        first = self.timeseries[0]
        if first.size <= self.max_size:
            self.plot = first.plot()
        else:
            self.plot = first.plot(linestyle='None', marker='.')
        self.ymin = first.min().value
        self.ymax = first.max().value
        self.xmin = first.times.value.min()
        self.xmax = first.times.value.max()

        for idx in range(1, len(self.timeseries)):
            ts = self.timeseries[idx]
            lbl = label_to_latex(ts.channel.name)
            if ts.size <= self.max_size:
                self.plot.add_timeseries(ts, label=lbl)
            else:
                # very long series are drawn as points rather than lines
                self.plot.add_timeseries(ts, label=lbl,
                                         linestyle='None', marker='.')
            self.ymin = min(self.ymin, ts.min().value)
            self.ymax = max(self.ymax, ts.max().value)
            self.xmin = min(self.xmin, ts.times.value.min())
            self.xmax = max(self.xmax, ts.times.value.max())

        # if they chose to set the range of the x-axis find the range of y
        strt = float(args.xmin) if args.xmin else self.xmin
        stop = float(args.xmax) if args.xmax else self.xmax
        if strt != self.xmin or stop != self.xmax:
            # small values are offsets from the first series' start;
            # resolve them once so the offset is not re-applied per series
            ref_x0 = first.x0.value
            if strt < 1e8:
                strt += ref_x0
            if stop < 1e8:
                stop += ref_x0

            # a bit weird but global ymax will be >= any value in
            # the range same for ymin
            new_ymin = self.ymax
            new_ymax = self.ymin
            found = False
            for ts in self.timeseries:
                x0 = ts.x0.value
                dt = ts.dt.value
                # sample indices of the window, clipped to this series
                # (both are sample counts, never compared with times)
                b = int(max(0, (strt - x0) / dt))
                e = int(min(ts.size, (stop - x0) / dt))
                if e <= b:
                    continue  # window does not overlap this series
                window = ts.value[b:e]
                new_ymin = min(new_ymin, npmin(window))
                new_ymax = max(new_ymax, npmax(window))
                found = True
            # keep the global limits if no series had data in the window
            if found:
                self.ymin = new_ymin
                self.ymax = new_ymax
        if self.yscale_factor > 1:
            self.log(2, ('Scaling y-limits, original: %f, %f)' %
                         (self.ymin, self.ymax)))
            yrange = self.ymax - self.ymin
            mid = (self.ymax + self.ymin) / 2.
            self.ymax = mid + yrange / (2 * self.yscale_factor)
            self.ymin = mid - yrange / (2 * self.yscale_factor)
            self.log(2, ('Scaling y-limits, new: %f, %f)' %
                         (self.ymin, self.ymax)))
Esempio n. 22
0
    def draw(self, outputfile=None):
        """Draw a bar chart of the livetime of each flag in ``self.flags``.

        The ``scale`` plot parameter (default ``'percent'``) sets the
        y-axis: with ``'percent'`` each bar is the flag's active time as a
        percentage of its known time; with a number it is the active time
        divided by that number. The ``sorted`` parameter orders the bars by
        descending value.

        Parameters
        ----------
        outputfile : str, optional
            passed through to ``self.finalize``

        Returns
        -------
        the return value of ``self.finalize``
        """
        (plot, axes) = self.init_plot(plot=Plot)
        ax = axes[0]

        if self.state:
            self.pargs.setdefault(
                'suptitle',
                '[%s-%s, state: %s]' % (self.span[0], self.span[1],
                                        label_to_latex(str(self.state))))
        else:
            self.pargs.setdefault(
                'suptitle', '[%s-%s]' % (self.span[0], self.span[1]))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        scale = self.pargs.pop('scale', 'percent')
        if scale == 'percent':
            self.pargs.setdefault('ylim', (0, 100))
        elif isinstance(scale, (int, float)):
            self.pargs.setdefault('ylim', (0, abs(self.span) / scale))
        try:
            self.pargs.setdefault('ylabel', 'Livetime [%s]'
                                  % self.SCALE_UNIT[scale])
        except KeyError:
            self.pargs.setdefault('ylabel', 'Livetime')

        # extract plotting arguments
        sort = self.pargs.pop('sorted', False)
        plotargs = self.parse_plot_kwargs()

        # get segments
        # NOTE(review): if ``scale`` is neither 'percent' nor a number no
        # values are appended below, so no bars are drawn
        data = []
        labels = plotargs.pop('labels', self.flags)
        for flag in self.flags:
            if self.state and not self.all_data:
                valid = self.state.active
            else:
                valid = SegmentList([self.span])
            segs = get_segments(flag, validity=valid, query=False,
                                padding=self.padding).coalesce()
            livetime = float(abs(segs.active))
            if scale == 'percent':
                try:
                    data.append(100 * livetime / float(abs(segs.known)))
                except ZeroDivisionError:
                    data.append(0)
            elif isinstance(scale, (float, int)):
                data.append(livetime / scale)

        # sorting nothing would fail to unpack zip(*[]), so skip it
        if sort and data:
            data, labels = zip(*sorted(
                zip(data, labels), key=lambda x: x[0], reverse=True))

        # make bar chart (the returned container is not needed, and
        # indexing it fails when there are no bars)
        width = plotargs.pop('width', .8)
        x = numpy.arange(len(data)) - width/2.
        ax.bar(x, data, width=width, **plotargs)

        # set labels
        ax.set_xticks(range(len(data)))
        ax.set_xticklabels(labels, rotation=30,
                           rotation_mode='anchor', ha='right', fontsize=13)
        ax.tick_params(axis='x', pad=2)
        ax.xaxis.labelpad = 2
        ax.xaxis.grid(False)
        self.pargs.setdefault('xlim', (-.5, len(data)-.5))

        # customise plot: call ax.set_<key>(val) where it exists,
        # otherwise store the value as an attribute on the axes
        for key, val in self.pargs.items():
            try:
                getattr(ax, 'set_%s' % key)(val)
            except AttributeError:
                setattr(ax, key, val)

        # add bit mask axes and finalise
        self.pargs['xlim'] = None
        return self.finalize(outputfile=outputfile, transparent="True",
                             pad_inches=0)
Esempio n. 23
0
    def annotate_save_plot(self, args):
        """After the derived class generated a plot object finish the process.

        Sets up the axes, colorbar and legend, builds the title, axis labels
        and suptitle (passed through ``label_to_latex``), optionally draws a
        grid, relabels an ``auto-gps`` x-axis, and saves the figure.

        Parameters
        ----------
        args : argparse.Namespace
            parsed command-line arguments; reads ``nocolorbar``, ``title``,
            ``xlabel``, ``ylabel``, ``nogrid``, ``suptitle``, ``out`` and,
            if present, ``outdir``

        The figure is written to ``args.out`` if given, otherwise to
        ``<args.outdir>/gwpy.png`` or ``./gwpy.png``; the file type follows
        the output file extension.
        """
        from astropy.time import Time
        from gwpy.plotter.tex import label_to_latex
        # import the submodule explicitly: a bare ``import matplotlib``
        # does not guarantee that ``matplotlib.pyplot`` is loaded
        import matplotlib.pyplot

        self.ax = self.plot.gca()
        # set up axes
        self.setup_xaxis(args)
        self.setup_yaxis(args)
        self.setup_iaxis(args)

        if self.is_image():
            if args.nocolorbar:
                self.plot.add_colorbar(visible=False)
            else:
                self.plot.add_colorbar(label=self.get_color_label())
        else:
            self.plot.add_colorbar(visible=False)

        # image plots don't have legends
        if not self.is_image():
            leg = self.ax.legend(prop={'size': 10})
            # if only one series is plotted hide legend
            if self.n_datasets == 1 and leg:
                try:
                    leg.remove()
                except NotImplementedError:
                    leg.set_visible(False)

        # add titles
        title = ''
        if args.title:
            for t in args.title:
                if len(title) > 0:
                    title += "\n"
                title += t
        # info on the processing
        start = self.start_list[0]
        startGPS = Time(start, format='gps', scale='utc')
        timeStr = "%s - %10d (%ds)" % (startGPS.iso, start, self.dur)

        # list the different sample rates available in all time series
        fs_set = set(ts.sample_rate for ts in self.timeseries)
        fs_str = ', '.join('(%s)' % fs for fs in fs_set)

        if self.is_freq_plot:
            spec = r'%s, Fs=%s, secpfft=%.1f (bw=%.3f), overlap=%.2f' % \
                (timeStr, fs_str, self.secpfft, 1/self.secpfft, self.overlap)
        else:
            xdur = self.xmax - self.xmin
            spec = r'Fs=%s, duration: %.1f' % (fs_str, xdur)
        spec += ", " + self.filter
        if len(title) > 0:
            title += "\n"
        if self.title2:
            title += self.title2
        else:
            title += spec

        title = label_to_latex(title)
        self.ax.set_title(title, fontsize=12)
        self.log(3, ('Title is: %s' % title))

        if args.xlabel:
            xlabel = label_to_latex(args.xlabel)
        else:
            xlabel = self.get_xlabel()
        if xlabel:
            self.ax.set_xlabel(xlabel)
            self.log(3, ('X-axis label is: %s' % xlabel))

        all_units = set(str(ts.unit) for ts in self.timeseries)
        if len(all_units) == 1:
            self.units = label_to_latex(all_units.pop())
        elif len(all_units) > 1:
            self.units = 'Multiple units'
        else:
            self.units = 'undef'

        if args.ylabel:
            ylabel = label_to_latex(args.ylabel)
        else:
            ylabel = self.get_ylabel(args)

        if ylabel:
            self.ax.set_ylabel(ylabel)
            self.log(3, ('Y-axis label is: %s' % ylabel))

        if not args.nogrid:
            # pass the on/off flag positionally: its keyword is 'b' in old
            # matplotlib and 'visible' in newer releases
            self.ax.grid(True, which='major', color='k', linestyle='solid')
            self.ax.grid(True, which='minor', color='0.06',
                         linestyle='dotted')

        # info on the channel
        if args.suptitle:
            sup_title = args.suptitle
        else:
            sup_title = self.get_sup_title()
        sup_title = label_to_latex(sup_title)
        self.plot.suptitle(sup_title, fontsize=18)

        self.log(3, ('Super title is: %s' % sup_title))
        self.show_plot_info()

        # change the label for GPS time so Josh is happy
        if self.ax.get_xscale() == 'auto-gps':
            xscale = self.ax.xaxis._scale
            epoch = xscale.get_epoch()
            unit = xscale.get_unit_name()
            utc = re.sub(r'\.0+', '',
                         Time(epoch, format='gps', scale='utc').iso)
            self.ax.set_xlabel('Time (%s) from %s (%s)' % (unit, utc, epoch))
            self.ax.set_xscale(unit, epoch=epoch)

        # if they specified an output file write it
        # save the figure. Note type depends on extension of
        # output filename (png, jpg, pdf)
        if args.out:
            out_file = args.out
        elif 'outdir' in args:
            out_file = args.outdir + '/gwpy.png'
        else:
            out_file = "./gwpy.png"

        self.log(3, ('xinch: %.2f, yinch: %.2f, dpi: %d' %
                     (self.xinch, self.yinch, self.dpi)))

        self.fig = matplotlib.pyplot.gcf()
        self.fig.set_size_inches(self.xinch, self.yinch)
        self.plot.savefig(out_file,
                          edgecolor='white',
                          figsize=[self.xinch, self.yinch],
                          dpi=self.dpi,
                          bbox_inches='tight')
        self.log(3, ('wrote %s' % out_file))
Esempio n. 24
0
    def draw(self, outputfile=None):
        """Draw a pie chart of the active time of each flag in ``self.flags``.

        If the ``include_future`` parameter is set, an extra (white, blank
        labelled) wedge covers the part of ``self.span`` not accounted for
        by the flags. The legend lists each label with its percentage of the
        total. Entries can be sorted by value (legend ``sorted``) and
        restricted to wedges whose value is at least ``threshold`` (same
        units as the active-time totals).

        Parameters
        ----------
        outputfile : str, optional
            passed through to ``self.finalize``

        Returns
        -------
        the return value of ``self.finalize``
        """
        (plot, axes) = self.init_plot(plot=Plot)
        ax = axes[0]

        # extract plotting arguments
        future = self.pargs.pop('include_future', False)
        legendargs = self.parse_legend_kwargs()
        wedgeargs = self.parse_wedge_kwargs()
        plotargs = self.parse_plot_kwargs()

        # use state to generate suptitle with GPS span
        if self.state:
            self.pargs.setdefault(
                'suptitle',
                '[%s-%s, state: %s]' % (self.span[0], self.span[1],
                                        label_to_latex(str(self.state))))
        else:
            self.pargs.setdefault(
                'suptitle', '[%s-%s]' % (self.span[0], self.span[1]))

        # get segments
        data = []
        for flag in self.flags:
            if self.state and not self.all_data:
                valid = self.state.active
            else:
                valid = SegmentList([self.span])
            segs = get_segments(flag, validity=valid, query=False,
                                padding=self.padding).coalesce()
            data.append(float(abs(segs.active)))
        if future:
            total = sum(data)
            alltime = abs(self.span)
            data.append(alltime-total)
            if 'labels' in plotargs:
                plotargs['labels'] = list(plotargs['labels']) + [' ']
            if 'colors' in plotargs:
                plotargs['colors'] = list(plotargs['colors']) + ['white']

        # make pie; without explicit labels, fall back to the flag names
        # (plus a blank entry for the 'future' wedge)
        labels = plotargs.pop('labels', None)
        if labels is None:
            labels = [str(flag) for flag in self.flags]
            if future:
                labels.append(' ')
        patches = ax.pie(data, **plotargs)[0]
        ax.axis('equal')

        # set wedge params
        for wedge in patches:
            for key, val in wedgeargs.items():
                getattr(wedge, 'set_%s' % key)(val)

        # make legend
        legendargs['title'] = self.pargs.pop('title', None)
        legth = legendargs.pop('threshold', 0)
        legsort = legendargs.pop('sorted', False)
        tot = float(sum(data))
        pclabels = []
        for d, label in zip(data, labels):
            if not label or label == ' ':
                pclabels.append(label)
            else:
                try:
                    pc = d/tot * 100
                except ZeroDivisionError:
                    pc = 0.0
                pclabels.append(label_to_latex(
                    '%s [%1.1f%%]' % (label, pc)).replace(r'\\', '\\'))

        # add time to top: an invisible handle carries the suptitle as the
        # first legend entry
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            extra = Rectangle((0, 0), 1, 1, fc='w', fill=False, ec='none',
                              linewidth=0)
        # sort entries
        if legsort:
            patches, pclabels, data = map(list, zip(*sorted(
                zip(patches, pclabels, data),
                key=lambda x: x[2],
                reverse=True)))
        # and restrict to the given threshold (if nothing passes, zip(*[])
        # fails to unpack and all entries are kept)
        if legth:
            try:
                patches, pclabels, data = map(list, zip(*[
                    x for x in zip(patches, pclabels, data) if x[2] >= legth]))
            except ValueError:
                pass

        if suptitle:
            leg = ax.legend([extra]+patches, [suptitle]+pclabels, **legendargs)
            t = leg.get_texts()[0]
            t.set_fontproperties(t.get_fontproperties().copy())
            t.set_size(min(12, t.get_size()))
        else:
            leg = ax.legend(patches, pclabels, **legendargs)
        legt = leg.get_title()
        legt.set_fontsize(max(22, legendargs.get('fontsize', 22)+4))
        legt.set_ha('left')

        # customise plot: call ax.set_<key>(val) where it exists,
        # otherwise store the value as an attribute on the axes
        for key, val in self.pargs.items():
            try:
                getattr(ax, 'set_%s' % key)(val)
            except AttributeError:
                setattr(ax, key, val)

        # copy title and move axes
        if ax.get_title():
            title = plot.suptitle(ax.get_title())
            title.update_from(ax.title)
            title.set_y(title._y + 0.05)
            ax.set_title('')
        axpos = ax.get_position()
        offset = -.2
        ax.set_position([axpos.x0+offset, .1, axpos.width, .8])

        # add bit mask axes and finalise
        self.pargs['xlim'] = None
        return self.finalize(outputfile=outputfile, transparent="True",
                             pad_inches=0)
Esempio n. 25
0
    def _process(self):
        """Load all data, and generate this `SpectrumDataPlot`.

        Draws one spectrum per channel, with percentile bands when the
        ``no_percentiles`` parameter reads as 'false'. Reference spectra
        are read from files named by ``reference`` / ``referenceN``
        parameters; their plot options come from ``reference[-_]opt`` /
        ``referenceN[-_]opt`` parameters. Horizontal lines come from
        ``hline``; a trailing dict in that list gives their line options.

        Returns
        -------
        the return value of ``self.finalize()``
        """
        plot = self.plot = SpectrumPlot(
            figsize=self.pargs.pop('figsize', [12, 6]))
        ax = plot.gca()
        # pass the on/off flag positionally: its keyword is 'b' in old
        # matplotlib and 'visible' in newer releases
        ax.grid(True, axis='both', which='both')

        if self.state:
            self.pargs.setdefault(
                'suptitle',
                '[%s-%s, state: %s]' % (self.span[0], self.span[1],
                                        label_to_latex(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # get spectrum format: 'amplitude' or 'power'
        sdform = self.pargs.pop('format')
        use_percentiles = str(
            self.pargs.pop('no_percentiles')).lower() == 'false'

        # parse plotting arguments
        plotargs = self.parse_plot_kwargs()
        legendargs = self.parse_legend_kwargs()

        # get reference arguments: collect all sources first, then attach
        # each option to its reference by name (grouping by sorted order
        # would mis-assign 'reference_opt' once 'reference1' exists)
        refs = []
        refindex = {}
        for key in sorted(self.pargs.keys()):
            if re.match(r'reference\d*\Z', key):
                refindex[key] = len(refs)
                refs.append({'source': self.pargs.pop(key)})
        for key in sorted(self.pargs.keys()):
            match = re.match(r'(reference\d*)[-_](.+)\Z', key)
            if match and match.group(1) in refindex:
                refs[refindex[match.group(1)]][match.group(2)] = (
                    self.pargs.pop(key))

        # add data
        for channel, pargs in zip(self.channels, plotargs):
            if self.state and not self.all_data:
                valid = self.state
            else:
                valid = SegmentList([self.span])
            data = get_spectrum(str(channel), valid, query=False,
                                format=sdform)

            # anticipate log problems
            if self.pargs['logx']:
                data = [s[1:] for s in data]
            if self.pargs['logy']:
                for sp in data:
                    sp.value[sp.value == 0] = 1e-100

            if use_percentiles:
                ax.plot_spectrum_mmm(*data, **pargs)
            else:
                pargs.pop('alpha', None)
                ax.plot_spectrum(data[0], **pargs)

            # allow channel data to set parameters
            if getattr(channel, 'frequency_range', None) is not None:
                self.pargs.setdefault('xlim', channel.frequency_range)
                if isinstance(self.pargs['xlim'], Quantity):
                    self.pargs['xlim'] = self.pargs['xlim'].value
            if (sdform in ['amplitude', 'asd'] and
                    hasattr(channel, 'asd_range')):
                self.pargs.setdefault('ylim', channel.asd_range)
            elif hasattr(channel, 'psd_range'):
                self.pargs.setdefault('ylim', channel.psd_range)

        # display references
        for ref in refs:
            if 'source' not in ref:
                continue
            source = ref.pop('source')
            try:
                refspec = Spectrum.read(source)
            except IOError as e:
                warnings.warn('IOError: %s' % str(e))
                continue
            except Exception as e:
                # hack for old versions of GWpy
                # TODO: remove me when GWSumm requires GWpy > 0.1
                if 'Format could not be identified' in str(e):
                    refspec = Spectrum.read(source, format='dat')
                else:
                    raise
            # plot outside the try so the 'dat' fallback is drawn too
            ref.setdefault('zorder', -len(refs) + 1)
            if 'filter' in ref:
                refspec = refspec.filter(*ref.pop('filter'))
            if 'scale' in ref:
                refspec *= ref.pop('scale', 1)
            ax.plot(refspec, **ref)

        # customise: call ax.set_<key>(val) where it exists, otherwise
        # store the value as an attribute on the axes
        hlines = list(self.pargs.pop('hline', []))
        for key, val in self.pargs.items():
            try:
                getattr(ax, 'set_%s' % key)(val)
            except AttributeError:
                setattr(ax, key, val)

        # add horizontal lines; only a trailing dict is taken as options
        lineparams = {'color': 'r', 'linestyle': '--'}
        if hlines and isinstance(hlines[-1], dict):
            lineparams = hlines.pop(-1)
        for yval in hlines:
            try:
                yval = float(yval)
            except ValueError:
                continue
            else:
                ax.plot(ax.get_xlim(), [yval, yval], **lineparams)

        if len(self.channels) > 1 or ax.legend_ is not None:
            plot.add_legend(ax=ax, **legendargs)
        if not plot.colorbars:
            plot.add_colorbar(ax=ax, visible=False)

        return self.finalize()
Esempio n. 26
0
 def test_label_to_latex(self):
     """label_to_latex keeps plain text and escapes underscores."""
     cases = [
         ('Test', 'Test'),
         ('Test_with_underscore', r'Test\_with\_underscore'),
     ]
     for raw, expected in cases:
         self.assertEqual(label_to_latex(raw), expected)
Esempio n. 27
0
    def _process(self):
        """Load all data, and generate this `SpectrumDataPlot`.

        Draws one spectrum per channel, with percentile bands when the
        ``no_percentiles`` parameter reads as 'false'. Reference spectra
        are read from files named by ``reference`` / ``referenceN``
        parameters; their plot options come from ``reference[-_]opt`` /
        ``referenceN[-_]opt`` parameters. Horizontal lines come from
        ``hline``; a trailing dict in that list gives their line options.

        Returns
        -------
        the return value of ``self.finalize()``
        """
        plot = self.plot = SpectrumPlot(
            figsize=self.pargs.pop('figsize', [12, 6]))
        ax = plot.gca()
        # pass the on/off flag positionally: its keyword is 'b' in old
        # matplotlib and 'visible' in newer releases
        ax.grid(True, axis='both', which='both')

        if self.state:
            self.pargs.setdefault(
                'suptitle', '[%s-%s, state: %s]' %
                (self.span[0], self.span[1], label_to_latex(str(self.state))))
        suptitle = self.pargs.pop('suptitle', None)
        if suptitle:
            plot.suptitle(suptitle, y=0.993, va='top')

        # get spectrum format: 'amplitude' or 'power'
        sdform = self.pargs.pop('format')
        use_percentiles = str(
            self.pargs.pop('no_percentiles')).lower() == 'false'

        # parse plotting arguments
        plotargs = self.parse_plot_kwargs()
        legendargs = self.parse_legend_kwargs()

        # get reference arguments: collect all sources first, then attach
        # each option to its reference by name (grouping by sorted order
        # would mis-assign 'reference_opt' once 'reference1' exists)
        refs = []
        refindex = {}
        for key in sorted(self.pargs.keys()):
            if re.match(r'reference\d*\Z', key):
                refindex[key] = len(refs)
                refs.append({'source': self.pargs.pop(key)})
        for key in sorted(self.pargs.keys()):
            match = re.match(r'(reference\d*)[-_](.+)\Z', key)
            if match and match.group(1) in refindex:
                refs[refindex[match.group(1)]][match.group(2)] = (
                    self.pargs.pop(key))

        # add data
        for channel, pargs in zip(self.channels, plotargs):
            if self.state and not self.all_data:
                valid = self.state
            else:
                valid = SegmentList([self.span])
            data = get_spectrum(str(channel),
                                valid,
                                query=False,
                                format=sdform)

            # anticipate log problems
            if self.pargs['logx']:
                data = [s[1:] for s in data]
            if self.pargs['logy']:
                for sp in data:
                    sp.value[sp.value == 0] = 1e-100

            if use_percentiles:
                ax.plot_spectrum_mmm(*data, **pargs)
            else:
                pargs.pop('alpha', None)
                ax.plot_spectrum(data[0], **pargs)

            # allow channel data to set parameters
            if getattr(channel, 'frequency_range', None) is not None:
                self.pargs.setdefault('xlim', channel.frequency_range)
                if isinstance(self.pargs['xlim'], Quantity):
                    self.pargs['xlim'] = self.pargs['xlim'].value
            if (sdform in ['amplitude', 'asd']
                    and hasattr(channel, 'asd_range')):
                self.pargs.setdefault('ylim', channel.asd_range)
            elif hasattr(channel, 'psd_range'):
                self.pargs.setdefault('ylim', channel.psd_range)

        # display references
        for ref in refs:
            if 'source' not in ref:
                continue
            source = ref.pop('source')
            try:
                refspec = Spectrum.read(source)
            except IOError as e:
                warnings.warn('IOError: %s' % str(e))
                continue
            except Exception as e:
                # hack for old versions of GWpy
                # TODO: remove me when GWSumm requires GWpy > 0.1
                if 'Format could not be identified' in str(e):
                    refspec = Spectrum.read(source, format='dat')
                else:
                    raise
            # plot outside the try so the 'dat' fallback is drawn too
            ref.setdefault('zorder', -len(refs) + 1)
            if 'filter' in ref:
                refspec = refspec.filter(*ref.pop('filter'))
            if 'scale' in ref:
                refspec *= ref.pop('scale', 1)
            ax.plot(refspec, **ref)

        # customise: call ax.set_<key>(val) where it exists, otherwise
        # store the value as an attribute on the axes
        hlines = list(self.pargs.pop('hline', []))
        for key, val in self.pargs.items():
            try:
                getattr(ax, 'set_%s' % key)(val)
            except AttributeError:
                setattr(ax, key, val)

        # add horizontal lines; only a trailing dict is taken as options
        lineparams = {'color': 'r', 'linestyle': '--'}
        if hlines and isinstance(hlines[-1], dict):
            lineparams = hlines.pop(-1)
        for yval in hlines:
            try:
                yval = float(yval)
            except ValueError:
                continue
            else:
                ax.plot(ax.get_xlim(), [yval, yval], **lineparams)

        if len(self.channels) > 1 or ax.legend_ is not None:
            plot.add_legend(ax=ax, **legendargs)
        if not plot.colorbars:
            plot.add_colorbar(ax=ax, visible=False)

        return self.finalize()
Esempio n. 28
0
 def test_label_to_latex(self):
     """Plain labels pass through; underscores get a LaTeX escape."""
     plain = 'Test'
     self.assertEqual(label_to_latex(plain), plain)
     underscored = 'Test_with_underscore'
     escaped = r'Test\_with\_underscore'
     self.assertEqual(label_to_latex(underscored), escaped)
Esempio n. 29
0
    def process(self, outputfile=None):
        """Read in all necessary data, and generate the figure.

        Each channel group from ``self.get_channel_groups()`` is drawn on
        the first axes. Groups of several channels use
        ``plot_timeseries_mmm``. A single channel is drawn as one line per
        returned series, all sharing the first line's colour and a single
        legend label. An empty result still draws an empty series so that
        the label appears.

        Parameters
        ----------
        outputfile : str, optional
            passed through to ``self.finalize``

        Returns
        -------
        the return value of ``self.finalize(outputfile=outputfile)``
        """
        (plot, axes) = self.init_plot()
        ax = axes[0]

        plotargs = self.parse_plot_kwargs()
        legendargs = self.parse_legend_kwargs()

        # add data
        channels, groups = zip(*self.get_channel_groups())
        for clist, pargs in zip(groups, plotargs):
            # pad data request to over-fill plots (no gaps at the end)
            if self.state and not self.all_data:
                valid = self.state.active
            elif clist[0].sample_rate:
                valid = SegmentList([self.span.protract(
                    1/clist[0].sample_rate.value)])
            else:
                valid = SegmentList([self.span])
            # get data
            data = [get_timeseries(c, valid, query=False) for c in clist]
            if len(clist) > 1:
                data = [tsl.join(gap='pad', pad=numpy.nan) for tsl in data]
            flatdata = [ts for tsl in data for ts in tsl]
            # validate parameters
            for ts in flatdata:
                # double-check empty
                if (hasattr(ts, 'metadata') and
                        'x0' not in ts.metadata) or not ts.x0:
                    ts.epoch = self.start
                # double-check log scales
                if self.pargs.get('logy', False):
                    ts.value[ts.value == 0] = 1e-100
            # set label
            try:
                label = pargs.pop('label')
            except KeyError:
                try:
                    label = label_to_latex(flatdata[0].name)
                except IndexError:
                    label = clist[0]
                else:
                    # for SVG output append ' [channel]' unless the label
                    # already starts with the channel name's first
                    # '.'-separated part
                    if self.fileformat == 'svg':
                        chlabel = label_to_latex(str(flatdata[0].channel))
                        if not label.startswith(chlabel.split('.')[0]):
                            label += ' [%s]' % chlabel
            # plot groups or single TimeSeries
            if len(clist) > 1:
                ax.plot_timeseries_mmm(*data, label=label, **pargs)
            elif len(flatdata) == 0:
                ax.plot_timeseries(
                    data[0].EntryClass([], epoch=self.start, unit='s',
                                       name=label), label=label, **pargs)
            else:
                for ts in data[0]:
                    line = ax.plot_timeseries(ts, label=label, **pargs)[0]
                    # label only the first line; reuse its colour for the rest
                    label = None
                    pargs['color'] = line.get_color()

            # allow channel data to set parameters
            if len(flatdata):
                chan = get_channel(str(flatdata[0].channel))
            else:
                chan = get_channel(clist[0])
            if getattr(chan, 'amplitude_range', None) is not None:
                self.pargs.setdefault('ylim', chan.amplitude_range)

        # add horizontal lines to add
        for yval in self.pargs.get('hline', []):
            try:
                yval = float(yval)
            except ValueError:
                continue
            else:
                ax.plot([self.start, self.end], [yval, yval],
                        linestyle='--', color='red')

        # customise plot: call ax.set_<key>(val) where it exists,
        # otherwise store the value as an attribute on the axes
        for key, val in self.pargs.items():
            try:
                getattr(ax, 'set_%s' % key)(val)
            except AttributeError:
                setattr(ax, key, val)
        # NOTE(review): 'label' was popped from each pargs dict above, and
        # those dicts appear to be the same objects as plotargs[i], so
        # plotargs[0].get('label') looks always None here and the legend is
        # always added -- confirm intended behaviour
        if (len(channels) > 1 or plotargs[0].get('label', None) in
                [re.sub(r'(_|\\_)', r'\_', channels[0]), None]):
            plot.add_legend(ax=ax, **legendargs)

        # add extra axes and finalise
        if not plot.colorbars:
            plot.add_colorbar(ax=ax, visible=False)
        if self.state:
            self.add_state_segments(ax)
        return self.finalize(outputfile=outputfile)