Example #1
0
def scatterplot(data, name="", xlabel="", ylabel="", size=3):
    """
    Creates a scatter plot from x,y data.

    *data* is an iterable of (x, y) tuples.
    *name* is the chart title; *xlabel* / *ylabel* label the axes.
    *size* is passed as both dimensions to
    ``ShapeUtilities.createRegularCross`` to build the marker for each point.

    Neither axis is forced to include zero.

    Returns a ``Chart`` wrapping the JFreeChart object.
    """
    series = XYSeries("Values")
    for (i, j) in data:
        series.add(i, j)

    dataset = XYSeriesCollection()
    dataset.addSeries(series)
    chart = ChartFactory.createScatterPlot(name, xlabel, ylabel, dataset,
        PlotOrientation.VERTICAL, True, True, False)
    plot = chart.getPlot()

    # Configure the axes the factory actually attached to the plot;
    # standalone NumberAxis objects have no effect unless installed.
    plot.getDomainAxis().setAutoRangeIncludesZero(False)
    plot.getRangeAxis().setAutoRangeIncludesZero(False)

    plot.getRenderer().setSeriesShape(0,
        ShapeUtilities.createRegularCross(size, size))

    return Chart(chart)
    
Example #2
0
def scatterplot(data, name="", xlabel="", ylabel="", size=3):
    """
    Creates a scatter plot from x,y data.

    *data* is an iterable of (x, y) tuples.
    *name* is the chart title; *xlabel* / *ylabel* label the axes.
    *size* is passed as both dimensions to
    ``ShapeUtilities.createRegularCross`` to build the marker for each point.

    Neither axis is forced to include zero.

    Returns a ``Chart`` wrapping the JFreeChart object.
    """
    series = XYSeries("Values")
    for (i, j) in data:
        series.add(i, j)

    dataset = XYSeriesCollection()
    dataset.addSeries(series)
    chart = ChartFactory.createScatterPlot(name, xlabel, ylabel, dataset,
        PlotOrientation.VERTICAL, True, True, False)
    plot = chart.getPlot()

    # Configure the axes the factory actually attached to the plot;
    # standalone NumberAxis objects have no effect unless installed.
    plot.getDomainAxis().setAutoRangeIncludesZero(False)
    plot.getRangeAxis().setAutoRangeIncludesZero(False)

    plot.getRenderer().setSeriesShape(0,
        ShapeUtilities.createRegularCross(size, size))

    return Chart(chart)
    def __init__(self, automations):
        """
        Open an "Automation Viewer" window with one line per automation.

        *automations* is an iterable of automation names. Each name is looked
        up with ``ModbusPal.getAutomation`` and wrapped in an
        ``AutomationSeries``. That series is added to the chart dataset and
        registered as a window listener on the frame.
        """
        # Create the frame
        frame = JFrame("Automation Viewer")
        frame.setSize(500, 300)
        frame.setLayout(BorderLayout())
        frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE)

        # Create an XY dataset with one series per automation
        dataset = XYSeriesCollection()
        for autoname in automations:
            automation = ModbusPal.getAutomation(autoname)
            series = AutomationSeries(automation)
            dataset.addSeries(series)
            # NOTE(review): presumably lets each series react to the window
            # being opened/closed -- confirm against AutomationSeries.
            frame.addWindowListener(series)

        # Create chart
        chart = ChartFactory.createXYLineChart("Automation Viewer",
                                               "Time (seconds)", "Value",
                                               dataset,
                                               PlotOrientation.VERTICAL,
                                               Boolean.TRUE, Boolean.TRUE,
                                               Boolean.FALSE)
        panel = ChartPanel(chart)

        # Add chart to panel
        frame.add(panel, BorderLayout.CENTER)

        # Show the window only after its content is in place; components
        # added to an already-visible frame are not laid out until the
        # container is revalidated.
        frame.setVisible(True)
Example #4
0
def curve(data, name="", smooth=True, trid=True):
    """
    Creates a curve based on a list of (x,y) tuples.

    *name* is the key of the single series. Only the first two elements of
    each item in *data* are used.

    Setting *smooth* to ``True`` results in a spline renderer being used.

    Setting *trid* to ``True`` results in a 3D-effect line renderer being
    used. In this case the *smooth* argument is ignored.

    Returns a ``Chart`` wrapping the JFreeChart object.
    """
    dataset = XYSeriesCollection()
    xy = XYSeries(name)
    for d in data:
        xy.add(d[0], d[1])
    dataset.addSeries(xy)
    chart = ChartFactory.createXYLineChart(None, None, None, dataset,
                                           PlotOrientation.VERTICAL, True,
                                           True, False)

    # 3D takes precedence over smoothing, so install only the renderer that
    # will actually be used rather than one that is immediately replaced.
    if trid:
        chart.getXYPlot().setRenderer(XYLine3DRenderer())
    elif smooth:
        chart.getXYPlot().setRenderer(XYSplineRenderer())

    return Chart(chart)
Example #5
0
 def createChart(self):
     """
     Build an untitled scatter-plot chart from every series in self.allSeries.

     Axis labels are obtained by calling xlabel() and ylabel().
     Returns the JFreeChart instance.
     """
     collection = XYSeriesCollection()
     for seriesItem in self.allSeries:
         collection.addSeries(seriesItem)
     domainLabel = xlabel()
     rangeLabel = ylabel()
     return ChartFactory.createScatterPlot(
         None, domainLabel, rangeLabel, collection,
         PlotOrientation.VERTICAL, True, True, False)
Example #6
0
 def createChart(self):
     """
     Build an untitled scatter-plot chart from every series in self.allSeries.

     Axis labels are obtained by calling xlabel() and ylabel().
     Returns the JFreeChart instance.
     """
     collection = XYSeriesCollection()
     for seriesItem in self.allSeries:
         collection.addSeries(seriesItem)
     domainLabel = xlabel()
     rangeLabel = ylabel()
     return ChartFactory.createScatterPlot(
         None, domainLabel, rangeLabel, collection,
         PlotOrientation.VERTICAL, True, True, False)
Example #7
0
def plot(title, x_label, y_label, *curves):
    """
    Show an XY line chart in its own 400x300 window.

    *curves* are (legend, points) pairs, where *points* is an iterable of
    (x, y) tuples; each pair becomes one line labelled *legend*.
    *title* is used both as the chart title and the window title.

    Returns the ``ChartFrame`` so callers can adjust or close the window.
    """
    dataset = XYSeriesCollection()
    for legend, curve in curves:
        series = XYSeries(legend)
        for x, y in curve:
            series.add(x, y)
        dataset.addSeries(series)
    chart = ChartFactory.createXYLineChart(title, x_label, y_label, dataset,
                                           PlotOrientation.VERTICAL, True,
                                           True, False)
    frame = ChartFrame(title, chart)
    # Size the window before showing it so it does not first appear at its
    # default size and then visibly resize.
    frame.setSize(400, 300)
    frame.setVisible(True)
    return frame
Example #8
0
 def getDataSets(self, txName):
     """
     For a given transaction key, return a dict of JFreeChart datasets
     with data on tx/sec, response times, and bandwidth.

     The first call builds datasets from self.bucketList for every key of
     self._summaryData.getTxNumNameMap() and caches them in
     self._txNameDatasets; later calls only read the cache. The returned
     dict maps TX_SEC_KEY, FULL_RESPONSE_TIME_KEY, THROUGHPUT_KEY and
     SIMPLE_RESPONSE_TIME_KEY to their datasets.

     Raises KeyError if *txName* is not one of the cached keys.
     """
     logger.debug("getting data sets for " + txName)
     if self._txNameDatasets is None:
         logger.warn("Building data sets.")
         # Bandwidth and per-phase timing series are only filled for HTTP
         # analyses; ask once per build instead of once per bucket.
         # NOTE(review): assumes isHTTP() is constant while building.
         isHTTP = ga.constants.VORPAL.getPlugin("analyzer").isHTTP()
         # Build into a local dict and publish it only when complete, so a
         # failure part-way through does not leave a partial cache behind.
         datasets = {}
         txNums = self._summaryData.getTxNumNameMap().keys()
         for txNum in txNums:
             logger.debug("DEBUG: building DS for " + txNum)
             txSecDataset = XYSeriesCollection()
             bandwidthDataSet = XYSeriesCollection()
             simpleResponseTimeDataset = XYSeriesCollection()
             responseTimeDataset = DefaultTableXYDataset()
             txSecPassSeries = XYSeries("passed")
             txSecFailSeries = XYSeries("failed")
             responseTimeSeries = XYSeries("seconds")
             # Series destined for responseTimeDataset use
             # autoSort=True, allowDuplicateXValues=False -- presumably
             # required by DefaultTableXYDataset; confirm.
             finishTimeSeries = XYSeries("complete", True, False)
             resolveHostSeries = XYSeries("resolveHost", True, False)
             connectSeries = XYSeries("connect", True, False)
             firstByteSeries = XYSeries("firstByte", True, False)
             bandwidthSeries = XYSeries("KB/sec")
             for bucket in self.bucketList:
                 # x value: bucket start time / 1000 (presumably ms -> s).
                 t = bucket.getStartTime() / 1000.0
                 txSecPassSeries.add(t, bucket.getTxSecPass(txNum))
                 txSecFailSeries.add(t, bucket.getTxSecFail(txNum))
                 responseTimeSeries.add(t, bucket.getMeanResponseTime(txNum))
                 if isHTTP:
                     bandwidthSeries.add(t, bucket.getMeanThroughputKBSec(txNum))
                     finishTimeSeries.add(t, bucket.getMeanFinishTime(txNum))
                     resolveHostSeries.add(t, bucket.getMeanResolveHostTime(txNum))
                     connectSeries.add(t, bucket.getMeanConnectTime(txNum))
                     firstByteSeries.add(t, bucket.getMeanFirstByteTime(txNum))
             txSecDataset.addSeries(txSecPassSeries)
             txSecDataset.addSeries(txSecFailSeries)
             responseTimeDataset.addSeries(resolveHostSeries)
             responseTimeDataset.addSeries(connectSeries)
             responseTimeDataset.addSeries(firstByteSeries)
             responseTimeDataset.addSeries(finishTimeSeries)
             simpleResponseTimeDataset.addSeries(responseTimeSeries)
             bandwidthDataSet.addSeries(bandwidthSeries)
             datasets[txNum] = {
                 TX_SEC_KEY: txSecDataset,
                 FULL_RESPONSE_TIME_KEY: responseTimeDataset,
                 THROUGHPUT_KEY: bandwidthDataSet,
                 SIMPLE_RESPONSE_TIME_KEY: simpleResponseTimeDataset,
             }
         self._txNameDatasets = datasets
         logger.debug("DEBUG: done building data sets.")
     return self._txNameDatasets[txName]
Example #9
0
def regression(data, regtype=0):
    """
    Plot (x, y) points as dots with a fitted regression curve drawn over them.

    *data* is an iterable of (x, y) pairs. *regtype* selects the model:
    1 fits a power curve via ``Regression.getPowerRegression``; any other
    value fits a straight line via ``Regression.getOLSRegression``.

    The fitted curve is sampled at 100 points between the smallest and
    largest x in *data* and drawn in blue as a second dataset.

    Returns a ``Chart`` whose ``coeffs`` attribute holds the regression
    coefficients.
    """
    domainAxis = NumberAxis("x")
    domainAxis.setAutoRangeIncludesZero(False)
    rangeAxis = NumberAxis("y")
    rangeAxis.setAutoRangeIncludesZero(False)

    points = XYSeries("values")
    lo = hi = None
    for (px, py) in data:
        points.add(px, py)
        if hi is None:
            lo = hi = px
            continue
        if px > hi:
            hi = px
        if px < lo:
            lo = px

    collection = XYSeriesCollection()
    collection.addSeries(points)

    plot = XYPlot(collection, domainAxis, rangeAxis, XYDotRenderer())

    if regtype == 1:
        coefficients = Regression.getPowerRegression(collection, 0)
        fitted = PowerFunction2D(coefficients[0], coefficients[1])
        heading = "Power Regression"
    else:
        coefficients = Regression.getOLSRegression(collection, 0)
        fitted = LineFunction2D(coefficients[0], coefficients[1])
        heading = "Linear Regression"

    fittedData = DatasetUtilities.sampleFunction2D(
        fitted, lo, hi, 100, "Fitted Regression Line")
    plot.setDataset(1, fittedData)

    lineRenderer = XYLineAndShapeRenderer(True, False)
    lineRenderer.setSeriesPaint(0, Color.blue)
    plot.setRenderer(1, lineRenderer)

    result = Chart(JFreeChart(heading, JFreeChart.DEFAULT_TITLE_FONT, plot, True))
    result.coeffs = coefficients
    return result
Example #10
0
def regression(data, regtype=0):
    """
    Plot (x, y) points as dots with a fitted regression curve drawn over them.

    *data* is an iterable of (x, y) pairs. *regtype* selects the model:
    1 fits a power curve via ``Regression.getPowerRegression``; any other
    value fits a straight line via ``Regression.getOLSRegression``.

    The fitted curve is sampled at 100 points between the smallest and
    largest x in *data* and drawn in blue as a second dataset.

    Returns a ``Chart`` whose ``coeffs`` attribute holds the regression
    coefficients.
    """
    domainAxis = NumberAxis("x")
    domainAxis.setAutoRangeIncludesZero(False)
    rangeAxis = NumberAxis("y")
    rangeAxis.setAutoRangeIncludesZero(False)

    points = XYSeries("values")
    lo = hi = None
    for (px, py) in data:
        points.add(px, py)
        if hi is None:
            lo = hi = px
            continue
        if px > hi:
            hi = px
        if px < lo:
            lo = px

    collection = XYSeriesCollection()
    collection.addSeries(points)

    plot = XYPlot(collection, domainAxis, rangeAxis, XYDotRenderer())

    if regtype == 1:
        coefficients = Regression.getPowerRegression(collection, 0)
        fitted = PowerFunction2D(coefficients[0], coefficients[1])
        heading = "Power Regression"
    else:
        coefficients = Regression.getOLSRegression(collection, 0)
        fitted = LineFunction2D(coefficients[0], coefficients[1])
        heading = "Linear Regression"

    fittedData = DatasetUtilities.sampleFunction2D(
        fitted, lo, hi, 100, "Fitted Regression Line")
    plot.setDataset(1, fittedData)

    lineRenderer = XYLineAndShapeRenderer(True, False)
    lineRenderer.setSeriesPaint(0, Color.blue)
    plot.setRenderer(1, lineRenderer)

    result = Chart(JFreeChart(heading, JFreeChart.DEFAULT_TITLE_FONT, plot, True))
    result.coeffs = coefficients
    return result
Example #11
0
def graphArbitraryData(data,title=""):
    ''' Creates a line graph of arbitrary data (as opposed to the rather specific format used by graph()). Pass in a
        list of datapoints, each consisting of a 3-part tuple: (x,y,category). If you're only graphing one sort of
        data (ie you want a graph with one line), you can pass in any constant for category.

        Each distinct str(category) becomes one line; lines are added in the order their
        categories first appear in *data*. *title* is used as the window title. The chart
        is shown in its own window, and its plot object is returned so it can be customized.

        Example (one category):
            data = []
            for i in range(10):
                data.append((i,i*i,0))
            graphArbitraryData(data)

        Example (multiple categories):
            data = []
            for i in range(30):
                data.append((i,i*i,"squared"))
                data.append((i,i*i*i,"cubed"))
            graphArbitraryData(data)

    '''
    from org.jfree.chart import ChartFactory, ChartFrame
    from org.jfree.chart.plot import PlotOrientation
    from org.jfree.data.xy import XYSeriesCollection, XYSeries

    # First, create the individual series from the data. Keep the order in
    # which categories first appear so series order (and default colours)
    # is deterministic rather than depending on dict iteration order.
    seriesByName = {}
    seriesInOrder = []
    for item in data:
        seriesname = str(item[2])
        series = seriesByName.get(seriesname)
        if series is None:
            series = XYSeries(seriesname)
            seriesByName[seriesname] = series
            seriesInOrder.append(series)
        series.add(float(item[0]), float(item[1]))

    # Second, add those series to a collection
    datasetcollection = XYSeriesCollection()
    for series in seriesInOrder:
        datasetcollection.addSeries(series)

    chart = ChartFactory.createXYLineChart("", "", "", datasetcollection,
                                           PlotOrientation.VERTICAL,
                                           True, True, False)
    frame = ChartFrame(title, chart)
    frame.pack()
    frame.setVisible(True)
    return chart.getPlot()
Example #12
0
def xy(data, name="", xlabel="", ylabel=""):
    """
    Creates a xy bar chart.

    *data* is a list of (x, y) tuples; *name* is the series key and
    *xlabel* / *ylabel* label the axes.

    When *data* has at least two distinct x values, the bar width is set to
    the gap between the two smallest ones, so evenly spaced bars touch.

    Returns a ``Chart`` wrapping the JFreeChart object.
    """
    points = list(data)
    series = XYSeries(name)
    for x, y in points:
        series.add(x, y)

    dataset = XYSeriesCollection(series)
    # Hack to set interval width. Use sorted, distinct x values: the first
    # two input items may be out of order or share an x, which would give a
    # negative or zero width.
    xs = sorted(set(p[0] for p in points))
    if len(xs) > 1:
        dataset.setIntervalWidth(xs[1] - xs[0])

    chart = ChartFactory.createXYBarChart(
        None, xlabel, False, ylabel, dataset, PlotOrientation.VERTICAL, True, True, False
    )
    return Chart(chart)
Example #13
0
def xy(data, name='', xlabel='', ylabel=''):
    """
    Creates a xy bar chart.

    *data* is a list of (x, y) tuples; *name* is the series key and
    *xlabel* / *ylabel* label the axes.

    When *data* has at least two distinct x values, the bar width is set to
    the gap between the two smallest ones, so evenly spaced bars touch.

    Returns a ``Chart`` wrapping the JFreeChart object.
    """
    points = list(data)
    series = XYSeries(name)
    for x, y in points:
        series.add(x, y)

    dataset = XYSeriesCollection(series)
    # Hack to set interval width. Use sorted, distinct x values: the first
    # two input items may be out of order or share an x, which would give a
    # negative or zero width.
    xs = sorted(set(p[0] for p in points))
    if len(xs) > 1:
        dataset.setIntervalWidth(xs[1] - xs[0])

    chart = ChartFactory.createXYBarChart(None, xlabel, False, ylabel, dataset,
                                          PlotOrientation.VERTICAL, True, True,
                                          False)
    return Chart(chart)
Example #14
0
def curve(data, name="", smooth=True, trid=True):
    """
    Creates a curve based on a list of (x,y) tuples.

    *name* is the key of the single series. Only the first two elements of
    each item in *data* are used.

    Setting *smooth* to ``True`` results in a spline renderer being used.

    Setting *trid* to ``True`` results in a 3D-effect line renderer being
    used. In this case the *smooth* argument is ignored.

    Returns a ``Chart`` wrapping the JFreeChart object.
    """
    dataset = XYSeriesCollection()
    xy = XYSeries(name)
    for d in data:
        xy.add(d[0], d[1])
    dataset.addSeries(xy)
    chart = ChartFactory.createXYLineChart(None, None, None, dataset,
                                           PlotOrientation.VERTICAL, True,
                                           True, False)

    # 3D takes precedence over smoothing, so install only the renderer that
    # will actually be used rather than one that is immediately replaced.
    if trid:
        chart.getXYPlot().setRenderer(XYLine3DRenderer())
    elif smooth:
        chart.getXYPlot().setRenderer(XYSplineRenderer())

    return Chart(chart)
def plot2D(points, Ca, Cb, maxIntensity=255.0):
    """
    Show an 800x800 scatter plot of channel Cb (x axis) against channel Ca (y axis).

    *points* is an iterable of per-point value sequences; p[Ca] and p[Cb]
    are the two plotted values. Channel names are read as channels[C+1]
    from the module-level ``channels`` sequence. A point is positive for a
    channel when that channel name is a key of the module-level
    ``thresholds`` dict and its value is strictly greater than the threshold.

    Points are split into four series by their (+/-) classification and
    painted: NN dark grey, PN green, NP blue, PP cyan. Both axes are fixed to
    the range [0, *maxIntensity*] (255.0 by default).

    The chart and window title combine the module-level ``title`` with the
    two channel names.
    """
    nameA = channels[Ca+1]
    nameB = channels[Cb+1]

    seriesNN = XYSeries(nameA+" -ve "+nameB+" -ve")
    seriesPP = XYSeries(nameA+" +ve "+nameB+" +ve")
    seriesNP = XYSeries(nameA+" -ve "+nameB+" +ve")
    seriesPN = XYSeries(nameA+" +ve "+nameB+" -ve")

    # Threshold presence does not change per point; look it up once.
    hasThreshA = nameA in thresholds
    hasThreshB = nameB in thresholds
    for p in points:
        posA = hasThreshA and p[Ca] > thresholds[nameA]
        posB = hasThreshB and p[Cb] > thresholds[nameB]
        if posA and posB:
            target = seriesPP
        elif posA:
            target = seriesPN
        elif posB:
            target = seriesNP
        else:
            target = seriesNN
        target.add(p[Cb], p[Ca])

    # Order here fixes each series' index in the dataset and renderer.
    styled = [
        (seriesNN, Color(64,64,64)),   # NN
        (seriesPN, Color(0,255,0)),    # PN
        (seriesNP, Color(0,0,255)),    # NP
        (seriesPP, Color(0,255,255)),  # PP
    ]
    dataset = XYSeriesCollection()
    for series, _ in styled:
        dataset.addSeries(series)

    heading = title+" - "+nameB+" vs "+nameA
    chart = ChartFactory.createScatterPlot(heading, nameB, nameA, dataset, PlotOrientation.VERTICAL, False, True, False)
    plot = chart.getPlot()
    plot.getDomainAxis().setRange(Range(0.00, maxIntensity), True, False)
    plot.getRangeAxis().setRange(Range(0.00, maxIntensity), True, False)

    renderer = plot.getRenderer()
    shape = Ellipse2D.Float(-1,-1,3,3)
    for index, (series, paint) in enumerate(styled):
        renderer.setSeriesPaint(index, paint)
        renderer.setSeriesShape(index, shape)

    frame = ChartFrame(heading, chart)
    frame.setSize(800, 800)
    frame.setLocationRelativeTo(None)
    frame.setVisible(True)
Example #16
0
    def updateChartDataset(self, drawLabels=True):
        """
        Refresh the chart's datasets, renderers and axis ranges from this object.

        Dataset 0 holds self.markers, drawn as unconnected 6x6 ellipses in
        self.markerColor. When self.drawCI is true, dataset 1 holds the
        self.bottom, self.top and self.limit interval series drawn with a
        DeviationRenderer; otherwise dataset 1 is replaced with an empty
        collection. The range axis is set to [self.minY, self.maxY] and the
        domain axis to [self.minX, self.maxX].

        If *drawLabels* is true, self.drawLabels() is called afterwards.
        """
        markerDataset = XYSeriesCollection()
        markerDataset.addSeries(self.markers)
        plot = self.chart.getPlot()
        plot.getRangeAxis().setRange(self.minY, self.maxY)
        plot.getDomainAxis().setRange(self.minX, self.maxX)

        markerRenderer = XYLineAndShapeRenderer(False, True)
        markerRenderer.setSeriesPaint(0, self.markerColor)
        markerRenderer.setSeriesShape(0, Ellipse2D.Double(-3, -3, 6, 6))
        plot.setRenderer(0, markerRenderer)
        plot.setDataset(0, markerDataset)

        if self.drawCI:
            ciDataset = YIntervalSeriesCollection()
            ciDataset.addSeries(self.bottom)
            ciDataset.addSeries(self.top)
            ciDataset.addSeries(self.limit)
            CIRenderer = DeviationRenderer(True, False)
            CIRenderer.setSeriesFillPaint(0, self.balColor)
            CIRenderer.setSeriesFillPaint(1, self.neuColor)
            CIRenderer.setSeriesFillPaint(2, self.posColor)
            CIRenderer.setSeriesPaint(0, self.balColor)
            CIRenderer.setSeriesPaint(1, self.neuColor)
            CIRenderer.setSeriesPaint(2, self.posColor)
            plot.setDataset(1, ciDataset)
            plot.setRenderer(1, CIRenderer)
        else:
            # No interval bands: replace dataset 1 with an empty collection.
            plot.setDataset(1, XYSeriesCollection())
        if drawLabels:
            self.drawLabels()
Example #17
0
    def updateChartDataset(self, drawLabels=True):
        """
        Refresh the chart's datasets, renderers and axis ranges from this object.

        Dataset 0 holds self.markers, drawn as unconnected 6x6 ellipses in
        self.markerColor. When self.drawCI is true, dataset 1 holds the
        self.bottom, self.top and self.limit interval series drawn with a
        DeviationRenderer; otherwise dataset 1 is replaced with an empty
        collection. The range axis is set to [self.minY, self.maxY] and the
        domain axis to [0.0, self.maxX].

        If *drawLabels* is true, self.drawLabels() is called afterwards.
        """
        markerDataset = XYSeriesCollection()
        markerDataset.addSeries(self.markers)
        plot = self.chart.getPlot()
        plot.getRangeAxis().setRange(self.minY, self.maxY)
        plot.getDomainAxis().setRange(0.0, self.maxX)

        markerRenderer = XYLineAndShapeRenderer(False, True)
        markerRenderer.setSeriesPaint(0, self.markerColor)
        markerRenderer.setSeriesShape(0, Ellipse2D.Double(-3, -3, 6, 6))
        plot.setRenderer(0, markerRenderer)
        plot.setDataset(0, markerDataset)

        if self.drawCI:
            ciDataset = YIntervalSeriesCollection()
            ciDataset.addSeries(self.bottom)
            ciDataset.addSeries(self.top)
            ciDataset.addSeries(self.limit)
            CIRenderer = DeviationRenderer(True, False)
            CIRenderer.setSeriesFillPaint(0, self.balColor)
            CIRenderer.setSeriesFillPaint(1, self.neuColor)
            CIRenderer.setSeriesFillPaint(2, self.posColor)
            CIRenderer.setSeriesPaint(0, self.balColor)
            CIRenderer.setSeriesPaint(1, self.neuColor)
            CIRenderer.setSeriesPaint(2, self.posColor)
            plot.setDataset(1, ciDataset)
            plot.setRenderer(1, CIRenderer)
        else:
            # No interval bands: replace dataset 1 with an empty collection.
            plot.setDataset(1, XYSeriesCollection())
        if drawLabels:
            self.drawLabels()
Example #18
0
 def getDataSets(self, txName):
     '''
     For a given transaction key, return a dict of JFreeChart datasets
     with data on tx/sec, response times, and bandwidth.

     The first call builds datasets from self.bucketList for every key of
     self._summaryData.getTxNumNameMap() and caches them in
     self._txNameDatasets; later calls only read the cache. The returned
     dict maps TX_SEC_KEY, FULL_RESPONSE_TIME_KEY, THROUGHPUT_KEY and
     SIMPLE_RESPONSE_TIME_KEY to their datasets.

     Raises KeyError if *txName* is not one of the cached keys.
     '''
     logger.debug("getting data sets for " + txName)
     if self._txNameDatasets is None:
         logger.warn("Building data sets.")
         # Bandwidth and per-phase timing series are only filled for HTTP
         # analyses; ask once per build instead of once per bucket.
         # NOTE(review): assumes isHTTP() is constant while building.
         isHTTP = ga.constants.VORPAL.getPlugin("analyzer").isHTTP()
         # Build into a local dict and publish it only when complete, so a
         # failure part-way through does not leave a partial cache behind.
         datasets = {}
         txNums = self._summaryData.getTxNumNameMap().keys()
         for txNum in txNums:
             logger.debug("DEBUG: building DS for " + txNum)
             txSecDataset = XYSeriesCollection()
             bandwidthDataSet = XYSeriesCollection()
             simpleResponseTimeDataset = XYSeriesCollection()
             responseTimeDataset = DefaultTableXYDataset()
             txSecPassSeries = XYSeries("passed")
             txSecFailSeries = XYSeries("failed")
             responseTimeSeries = XYSeries("seconds")
             # Series destined for responseTimeDataset use
             # autoSort=True, allowDuplicateXValues=False -- presumably
             # required by DefaultTableXYDataset; confirm.
             finishTimeSeries = XYSeries("complete", True, False)
             resolveHostSeries = XYSeries("resolveHost", True, False)
             connectSeries = XYSeries("connect", True, False)
             firstByteSeries = XYSeries("firstByte", True, False)
             bandwidthSeries = XYSeries("KB/sec")
             for bucket in self.bucketList:
                 # x value: bucket start time / 1000 (presumably ms -> s).
                 t = bucket.getStartTime() / 1000.0
                 txSecPassSeries.add(t, bucket.getTxSecPass(txNum))
                 txSecFailSeries.add(t, bucket.getTxSecFail(txNum))
                 responseTimeSeries.add(t, bucket.getMeanResponseTime(txNum))
                 if isHTTP:
                     bandwidthSeries.add(t, bucket.getMeanThroughputKBSec(txNum))
                     finishTimeSeries.add(t, bucket.getMeanFinishTime(txNum))
                     resolveHostSeries.add(t, bucket.getMeanResolveHostTime(txNum))
                     connectSeries.add(t, bucket.getMeanConnectTime(txNum))
                     firstByteSeries.add(t, bucket.getMeanFirstByteTime(txNum))
             txSecDataset.addSeries(txSecPassSeries)
             txSecDataset.addSeries(txSecFailSeries)
             responseTimeDataset.addSeries(resolveHostSeries)
             responseTimeDataset.addSeries(connectSeries)
             responseTimeDataset.addSeries(firstByteSeries)
             responseTimeDataset.addSeries(finishTimeSeries)
             simpleResponseTimeDataset.addSeries(responseTimeSeries)
             bandwidthDataSet.addSeries(bandwidthSeries)
             datasets[txNum] = {
                 TX_SEC_KEY: txSecDataset,
                 FULL_RESPONSE_TIME_KEY: responseTimeDataset,
                 THROUGHPUT_KEY: bandwidthDataSet,
                 SIMPLE_RESPONSE_TIME_KEY: simpleResponseTimeDataset,
             }
         self._txNameDatasets = datasets
         logger.debug("DEBUG: done building data sets.")
     return self._txNameDatasets[txName]
Example #19
0
def scatter(data,x=None,y=None,**kwargs):
    ''' Creates a scatter plot comparing two elements. At minimum, takes a collection of data.
        The second and third arguments, if they exist, are treated as the two elements to 
        compare. If these arguments do not exist, the first two elements in the list are 
        compared. If an optional regress=True argument is present, superimposes a linear 
        regression for each series and prints some related info (R-value etc). 
        Note that scatter plots can be zoomed with the mouse. Returns the plot object in
        case you want to customize the graph in some way. Takes an optional showMissing argument 
        which determines whether missing values (-9999.0) should be displayed. 
        Examples: scatter(data), scatter(data,"tmin","tmax",regress=True)
    '''
    from org.jfree.data.xy import XYSeriesCollection,XYSeries
    from org.jfree.data import UnknownKeyException
    from org.jfree.chart import ChartFactory,ChartFrame
    from org.jfree.chart.plot import PlotOrientation,DatasetRenderingOrder
    from org.jfree.chart.renderer.xy import XYLineAndShapeRenderer
    from java.awt import Color
    
    regress=kwargs.get('regress',False)
    showMissing=kwargs.get('showMissing',False)
    
    # Try to be flexible about element parameters
    if x is not None: x = findElement(x).name
    if y is not None: y = findElement(y).name

    # Create a dataset from the data
    collection = XYSeriesCollection()
    for ob in data.groupedByObservation().items():
        key,values = ob
        name = str(key[0])
        if x==None:
            x = values[0].element.name
        try:
            xFact = (i for i in values if i.element.name == x).next()
        except StopIteration: # missing value
            continue
        xval  = xFact.value
        if xval in missingValues and not showMissing: continue  
        if y==None:
            try:
                y = values[1].element.name
            except IndexError:
                raise Exception("Error! Your data request returned only 1 value per observation. " 
                                "Must have 2 values to generate a scatter plot.")
        try:
            yFact = (i for i in values if i.element.name == y).next()
        except StopIteration: # missing value
            continue
        yval  = yFact.value
        if yval in missingValues and not showMissing: continue  
        
        try: 
            series = collection.getSeries(name)
        except UnknownKeyException:
            collection.addSeries(XYSeries(name))
            series = collection.getSeries(name)
        
        series.add(float(xval),float(yval))

    # Create chart from dataset        
    chart = ChartFactory.createScatterPlot( "", x, y, collection, PlotOrientation.VERTICAL,
                                            True, True, False );
    plot = chart.getPlot()
    frame = ChartFrame("Scatter Plot", chart);
    frame.pack();
    frame.setVisible(True);

    # Superimpose regression if desired
    if regress:
        regressioncollection = XYSeriesCollection()
        for series in collection.getSeries():
            regression = _getregression(series)
            x1 = series.getMinX()
            y1 = regression.predict(x1)
            x2 = series.getMaxX()
            y2 = regression.predict(x2)
            regressionseries = XYSeries(series.getKey())
            regressionseries.add(float(x1),float(y1))
            regressionseries.add(float(x2),float(y2))
            regressioncollection.addSeries(regressionseries)

            print series.getKey(),":"
            print "  R:            %8.4f" % regression.getR()
            print "  R-squared:    %8.4f" % regression.getRSquare()
            print "  Significance: %8.4f" % regression.getSignificance()
            print
            
        plot.setDataset(1,regressioncollection)
        regressionRenderer = XYLineAndShapeRenderer(True,False)
        plot.setRenderer(1,regressionRenderer)
        plot.setDatasetRenderingOrder(DatasetRenderingOrder.FORWARD);
        
        colors = [0xec0000,0x58b911,0x6886ea,0xedd612,0xa93bb9,0xffb71b,0xe200df,0x1de2b6,0xdc91db,0x383838,0xb09344,0x4ea958,0xd78c9e,0x64008d,0xb0c95b]
        mainRenderer = plot.getRenderer(0)
        for i in range(collection.getSeriesCount()):
            try:
                mainRenderer.setSeriesPaint(i,Color(colors[i]))
                regressionRenderer.setSeriesPaint(i,Color(colors[i]))
            except IndexError: # Finite # of colors in the color array; beyond that let jfreechart pick
                break
        '''
        # Jump through some hoops to ensure regressions are same color as scatters for each series.
        # Initially: doesn't work because series are not indexed the same. And I don't see a way
        # to get the actual series from the renderer in order to compare names or something.
        mainRenderer = plot.getRenderer(0)
        print "Renderer is",type(mainRenderer)
        index = 0
        paint = mainRenderer.lookupSeriesPaint(index)
        print "Paint is",type(paint)
        while (paint is not None):
            print "Setting paint."
            regressionRenderer.setSeriesPaint(index,paint)
            index += 1
            paint = mainRenderer.getSeriesPaint(index)
        '''
        return plot