def allBooleanType(self, arguments, atleast=None, atmost=None):
    """Helper method to require only boolean arguments.

    @type arguments: list of DataColumns
    @param arguments: Input DataColumns representing the pre-evaluated arguments of the function.
    @type atleast: int or None
    @param atleast: If None, no minimum number of arguments; otherwise, require at least this many.
    @type atmost: int or None
    @param atmost: If None, no maximum number of arguments; otherwise, require at most this many.
    @raise PmmlValidationError: If the condition is not met, raise an error.  Otherwise, silently pass.
    """
    count = len(arguments)
    if atleast is not None and count < atleast:
        raise defs.PmmlValidationError("Function \"%s\" requires at least %d arguments" % (self.name, atleast))
    if atmost is not None and count > atmost:
        raise defs.PmmlValidationError("Function \"%s\" requires at most %d arguments" % (self.name, atmost))
    # every argument must have evaluated to a boolean DataColumn
    for argument in arguments:
        if argument.fieldType.dataType != "boolean":
            raise defs.PmmlValidationError("Function \"%s\" requires all arguments to be boolean" % self.name)
    return self._typeReverseMap[BOOL]
def unserializeState(self):
    """Load the contents of this SerializedState into a new DataTableState.

    @rtype: DataTableState
    @return: The DataTableState that can be used as an input to new calculations.
    """
    state = DataTableState()
    for stateKey in self.childrenOfTag("SerializedStateKey"):
        numberOfChildren = len(stateKey)
        if numberOfChildren == 1 and stateKey.text is None:
            # exactly one embedded XML subtree: that subtree is the value
            value = stateKey.getchildren()[0]
        elif numberOfChildren == 0 and stateKey.text is not None:
            # no subtree but text content: the text is JSON-encoded
            value = json.loads(stateKey.text, object_hook=self.customJsonDecoder)
        else:
            raise defs.PmmlValidationError("SerializedStateKey must contain either an XML tree xor JSON text")
        state[stateKey["key"]] = value
    return state
def select(self, dataTable, functionTable, performanceTable):
    """Evaluate the expression or predicate, given input data and a function table.

    @type dataTable: DataTable
    @param dataTable: Contains the data to plot.
    @type functionTable: FunctionTable
    @param functionTable: Defines functions that may be used to transform data for plotting.
    @type performanceTable: PerformanceTable
    @param performanceTable: Measures and records performance (time and memory consumption) of the drawing process.
    @rtype: 1d Numpy array of bool
    @return: The result of the expression or predicate as a Numpy mask.
    """
    # a predicate child takes precedence over an expression child
    predicate = self.childOfClass(PmmlPredicate)
    if predicate is not None:
        return predicate.evaluate(dataTable, functionTable, performanceTable)
    expression = self.childOfClass(PmmlExpression)
    column = expression.evaluate(dataTable, functionTable, performanceTable)
    if not column.fieldType.isboolean():
        raise defs.PmmlValidationError("PlotSelection must evaluate to boolean, not %r" % column.fieldType)
    column._unlock()
    if column.mask is not None:
        # rows that are not VALID are forced to False in the selection
        NP("logical_and", column.data, NP(column.mask == defs.VALID), column.data)
    return column.data
def fieldTypeFromSignature(self, arguments):
    """Helper method to derive the resulting FieldType from the class's signature.

    This method assumes that the class has a class attribute named
    C{signatures} that maps TYPEOBJECT tuples (input type signature)
    to TYPEOBJECTs (output type).

    @type arguments: list of DataColumns
    @param arguments: Input DataColumns representing the pre-evaluated arguments of the function.
    @rtype: FieldType
    @return: The FieldType corresponding to the output type.
    @raise PmmlValidationError: If no matching signature is found, raise an error.
    """
    signature = tuple(self._typeMap[argument.fieldType.dataType] for argument in arguments)
    outputType = self.signatures.get(signature)
    if outputType is None:
        raise defs.PmmlValidationError("Function \"%s\" has no signature matching its arguments" % self.name)
    return self._typeReverseMap[outputType]
def evaluate(self, dataTable, functionTable, performanceTable):
    """Evaluate the expression, given input data and a function table.

    @type dataTable: DataTable
    @param dataTable: Contains the data to plot.
    @type functionTable: FunctionTable
    @param functionTable: Defines functions that may be used to transform data for plotting.
    @type performanceTable: PerformanceTable
    @param performanceTable: Measures and records performance (time and memory consumption) of the drawing process.
    @rtype: DataColumn
    @return: The result of the expression as a DataColumn.
    @raise PmmlValidationError: If the expression is not numeric, this method raises an error.
    """
    expression = self.childOfClass(PmmlExpression)
    dataColumn = expression.evaluate(dataTable, functionTable, performanceTable)
    # temporal types (date/time/dateTime) also count as plottable numbers
    isPlottable = dataColumn.fieldType.isnumeric() or dataColumn.fieldType.istemporal()
    if not isPlottable:
        raise defs.PmmlValidationError("PlotNumericExpression must evaluate to a number, not %s" % dataColumn.fieldType)
    return dataColumn
def determineScaleBins(numBins, low, high, array):
    """Determine the C{numBins}, C{low}, and C{high} of the histogram
    from explicitly set values where available and implicitly derived
    values where necessary.

    Explicitly set values always override implicit values derived
    from the dataset.

      - C{low}, C{high} implicit values are the extrema of the dataset.
      - C{numBins} implicit value is the Freedman-Diaconis heuristic
        for number of histogram bins.

    @type numBins: int or None
    @param numBins: Input number of bins.
    @type low: number or None
    @param low: Low edge.
    @type high: number or None
    @param high: High edge.
    @type array: 1d Numpy array of numbers
    @param array: Dataset to use to implicitly derive values.
    @rtype: 3-tuple
    @return: C{numBins}, C{low}, C{high}
    """
    generateLow = (low is None)
    generateHigh = (high is None)
    if generateLow:
        low = float(array.min())
    if generateHigh:
        high = float(array.max())

    if low == high:
        # degenerate range: widen symmetrically by one unit on each side
        low, high = low - 1.0, high + 1.0
    elif high < low:
        # inverted range: repair whichever edge was auto-generated
        if generateLow:
            low = high - 1.0
        elif generateHigh:
            high = low + 1.0
        else:
            raise defs.PmmlValidationError("PlotHistogram attributes low and high must be in the right order: low = %g, high = %g" % (low, high))
    else:
        # pad any auto-generated edge by 20% of the span
        if generateLow and generateHigh:
            low, high = low - 0.2*(high - low), high + 0.2*(high - low)
        elif generateLow:
            low = low - 0.2*(high - low)
        elif generateHigh:
            high = high + 0.2*(high - low)

    if numBins is None:
        # the Freedman-Diaconis rule
        q1, q3 = NP("percentile", array, [25.0, 75.0])
        binWidth = 2.0 * (q3 - q1) / math.pow(len(array), 1.0/3.0)
        if binWidth > 0.0:
            numBins = max(10, int(math.ceil((high - low)/binWidth)))
        else:
            numBins = 10

    return numBins, low, high
def functionAverage(self, dataColumn, whereMask, groupSelection, getstate, setstate):
    """Averages rows in a DataColumn, possibly with an SQL where mask and groupField.

    @type dataColumn: DataColumn
    @param dataColumn: The input data column.
    @type whereMask: 1d Numpy array of bool, or None
    @param whereMask: The result of the SQL where selection.
    @type groupSelection: 1d Numpy array of bool, or None.
    @param groupSelection: Rows corresponding to a particular value of the groupField.
    @type getstate: callable function
    @param getstate: Retrieve starting values from the DataTableState.
    @type setstate: callable function
    @param setstate: Store ending values to the DataTableState.
    @rtype: DataColumn
    @return: A column of averaged rows.
    """
    fieldType = FakeFieldType("double", "continuous")
    if dataColumn.fieldType.dataType not in ("integer", "float", "double"):
        raise defs.PmmlValidationError("Aggregate function \"average\" requires a numeric input field: \"integer\", \"float\", \"double\"")
    # weight each row 1.0, then zero out rows removed by the mask or selections
    weights = NP("ones", len(dataColumn), dtype=fieldType.dtype)
    if dataColumn.mask is not None:
        NP("logical_and", weights, NP(dataColumn.mask == defs.VALID), weights)
    if whereMask is not None:
        NP("logical_and", weights, whereMask, weights)
    if groupSelection is not None:
        NP("logical_and", weights, groupSelection, weights)
    numerator = NP("multiply", weights, dataColumn.data)
    denominator = weights
    if getstate is not None and len(dataColumn) > 0:
        startingState = getstate()
        if startingState is not None:
            startingNumerator, startingDenominator = startingState
            numerator[0] += startingNumerator
            denominator[0] += startingDenominator
    numerator = NP("cumsum", numerator)
    denominator = NP("cumsum", denominator)
    # running mean; zero denominators produce non-finite values, flagged INVALID
    data = NP(numerator / denominator)
    mask = NP(NP("logical_not", NP("isfinite", data)) * defs.INVALID)
    if not mask.any():
        mask = None
    if setstate is not None and len(dataColumn) > 0:
        setstate((numerator[-1], denominator[-1]))
    return DataColumn(fieldType, data, mask)
def evaluate(self, dataTable, functionTable, performanceTable, arguments):
    """Evaluate the built-in \"like\" function: match a regex pattern against a field.

    @type dataTable: DataTable
    @param dataTable: The input data.
    @type functionTable: FunctionTable
    @param functionTable: Defines functions that may be used to transform data.
    @type performanceTable: PerformanceTable
    @param performanceTable: Measures and records performance (time and memory consumption) of the calculation.
    @type arguments: list
    @param arguments: Exactly two arguments: the expression to test and a Constant regex pattern.
    @rtype: DataColumn
    @return: A boolean DataColumn; True where the pattern matches the (stringified) value.
    @raise PmmlValidationError: If the argument count is wrong, the pattern is not a Constant, or the regex does not compile.
    """
    performanceTable.begin("built-in \"%s\"" % self.name)
    fieldType = self._typeReverseMap[BOOL]
    if len(arguments) != 2:
        raise defs.PmmlValidationError("Function \"like\" requires exactly two arguments")
    if isinstance(arguments[1], Constant):
        pattern = arguments[1].evaluateOne(convertType=False)
        try:
            pattern = re.compile(pattern)
        # BUG FIX: was "except sre_constants", which names a *module*, not an
        # exception class, and therefore can never catch the compile failure;
        # re.compile raises sre_constants.error (a.k.a. re.error)
        except sre_constants.error as err:
            raise defs.PmmlValidationError("Could not compile regex pattern \"%s\": %s" % (pattern, str(err)))
    else:
        raise defs.PmmlValidationError("Function \"like\" requires its second argument (the regex pattern) to be a Constant")
    performanceTable.pause("built-in \"%s\"" % self.name)
    test = arguments[0].evaluate(dataTable, functionTable, performanceTable)
    performanceTable.unpause("built-in \"%s\"" % self.name)
    match = pattern.match    # hoist the bound method out of the per-row loop
    if test.fieldType.optype == "continuous":
        d = test.data
        data = NP("fromiter", (match(d[i]) is not None for i in xrange(len(dataTable))), dtype=fieldType.dtype, count=len(dataTable))
    else:
        d = test.data
        ds = test.fieldType.valueToString
        data = NP("fromiter", (match(ds(d[i])) is not None for i in xrange(len(dataTable))), dtype=fieldType.dtype, count=len(dataTable))
    performanceTable.end("built-in \"%s\"" % self.name)
    return DataColumn(fieldType, data, test.mask)
def prepare(self, state, dataTable, functionName, performanceTable, plotRange):
    """Prepare a plot element for drawing.

    This stage consists of calculating all quantities and
    determining the bounds of the data.  These bounds may be unioned
    with bounds from other plot elements that overlay this plot
    element, so the drawing (which requires a finalized coordinate
    system) cannot begin yet.

    This method modifies C{plotRange}.

    @type state: ad-hoc Python object
    @param state: State information that persists long enough to use quantities computed in C{prepare} in the C{draw} stage.  This is a work-around of lxml's refusal to let its Python instances maintain C{self} and it is unrelated to DataTableState.
    @type dataTable: DataTable
    @param dataTable: Contains the data to plot.
    @type functionName: string
    @param functionName: Name of the calling function context (not used by this method).
    @type performanceTable: PerformanceTable
    @param performanceTable: Measures and records performance (time and memory consumption) of the drawing process.
    @type plotRange: PlotRange
    @param plotRange: The bounding box of plot coordinates that this function will expand.
    """
    self._saveContext(dataTable)
    x1, y1 = float(self["x1"]), float(self["y1"])
    x2, y2 = float(self["x2"]), float(self["y2"])
    if x1 >= x2 or y1 >= y2:
        raise defs.PmmlValidationError("x1 must be less than x2 and y1 must be less than y2")
    if plotRange.xStrictlyPositive or plotRange.yStrictlyPositive:
        # log-scale axes cannot host an arbitrary SVG rectangle
        raise defs.PmmlValidationError("PlotSvgContent can only be properly displayed in linear coordinates")
    plotRange.xminPush(x1, self.fieldTypeNumeric, sticky=True)
    plotRange.yminPush(y1, self.fieldTypeNumeric, sticky=True)
    plotRange.xmaxPush(x2, self.fieldTypeNumeric, sticky=True)
    plotRange.ymaxPush(y2, self.fieldTypeNumeric, sticky=True)
def zValue(self, testDistributions, fieldName, dataColumn, state, performanceTable):
    """Calculate the score of a zValue TestStatistic.

    @type testDistributions: PmmlBinding
    @param testDistributions: The <TestDistributions> element.
    @type fieldName: string
    @param fieldName: The field name (for error messages).
    @type dataColumn: DataColumn
    @param dataColumn: The field.
    @type state: DataTableState
    @param state: The persistent state object (not used).
    @type performanceTable: PerformanceTable or None
    @param performanceTable: A PerformanceTable for measuring the efficiency of the calculation.
    @rtype: dict
    @return: A dictionary mapping PMML "feature" strings to DataColumns; zValue only defines the None key ("predictedValue").
    """
    if dataColumn.fieldType.dataType in ("object", "string", "boolean", "date", "time", "dateTime"):
        raise TypeError("Field \"%s\" has dataType \"%s\", which is incompatible with BaselineModel.zValue" % (fieldName, dataColumn.fieldType.dataType))
    # any Baseline distribution element carrying both a mean and a variance will do
    candidates = testDistributions.xpath("pmml:Baseline/*[@mean and @variance]")
    if len(candidates) == 0:
        raise defs.PmmlValidationError("BaselineModel zValue requires a distribution with a mean and a variance")
    distribution = candidates[0]
    mean = float(distribution.get("mean"))
    variance = float(distribution.get("variance"))
    if variance <= 0.0:
        raise defs.PmmlValidationError("Variance must be positive, not %g" % variance)
    zscores = NP(NP(dataColumn.data - mean) / math.sqrt(variance))
    return {None: DataColumn(self.scoreType, zscores, dataColumn.mask)}
def draw(self, dataTable, functionTable, performanceTable, rowIndex, colIndex, cellContents, labelAttributes, plotDefinitions):
    """Draw the plot legend content, which is more often text than graphics.

    @type dataTable: DataTable
    @param dataTable: Contains the data to describe, if any.
    @type functionTable: FunctionTable
    @param functionTable: Defines functions that may be used to transform data.
    @type performanceTable: PerformanceTable
    @param performanceTable: Measures and records performance (time and memory consumption) of the drawing process.
    @type rowIndex: int
    @param rowIndex: Row number of the C{cellContents} to fill.
    @type colIndex: int
    @param colIndex: Column number of the C{cellContents} to fill.
    @type cellContents: dict
    @param cellContents: Dictionary that maps pairs of integers to SVG graphics to draw.
    @type labelAttributes: CSS style dict
    @param labelAttributes: Style properties that are defined at the level of the legend and must percolate down to all drawables within the legend.
    @type plotDefinitions: PlotDefinitions
    @param plotDefinitions: The dictionary of key-value pairs that forms the <defs> section of the SVG document.
    @rtype: 2-tuple
    @return: The next C{rowIndex} and C{colIndex} in the sequence.
    """
    svg = SvgBinding.elementMaker
    svgId = self.get("svgId")
    output = svg.g() if svgId is None else svg.g(**{"id": svgId})
    inlineSvg = self.getchildren()
    fileName = self.get("fileName")
    if len(inlineSvg) == 1 and fileName is None:
        svgBinding = copy.deepcopy(inlineSvg[0])
    elif len(inlineSvg) == 0 and fileName is not None:
        svgBinding = SvgBinding.loadXml(fileName)
    else:
        raise defs.PmmlValidationError("PlotLegendSvg should specify an inline SVG or a fileName but not both or neither")
    sx1, sy1, sx2, sy2 = PlotSvgAnnotation.findSize(svgBinding)
    nominalHeight = sy2 - sy1
    nominalWidth = sx2 - sx1
    rowHeight = 30.0    # TODO: set this correctly from the text height
    # scale the graphic to the legend row height and center it on the cell
    output["transform"] = "translate(%r, %r) scale(%r, %r)" % (-sx1 - 0.5*nominalWidth*rowHeight/nominalHeight, -sy1 - 0.75*rowHeight, rowHeight/nominalHeight, rowHeight/nominalHeight)
    output.append(svgBinding)
    cellContents[rowIndex, colIndex] = svg.g(output)
    cellContents[rowIndex, colIndex].text = " "    # TODO: set the width correctly, too
    colIndex += 1
    return rowIndex, colIndex
def _checkIntervals(self, data, mask):
    """Flag rows that fall outside any declared <Interval> as INVALID.

    @type data: 1d Numpy array
    @param data: The data values.
    @type mask: 1d Numpy array of dtype defs.maskType, or None
    @param mask: The current mask, if any.
    @rtype: 2-tuple
    @return: The data and the (possibly updated) mask.
    """
    intervals = self.intervals
    if len(intervals) == 0:
        return data, mask
    # innocent until proven guilty
    outOfRange = NP("zeros", len(data), dtype=NP.dtype(bool))
    for interval in intervals:
        closure = interval["closure"]
        leftMargin = interval.get("leftMargin")
        rightMargin = interval.get("rightMargin")
        if leftMargin is not None:
            try:
                leftMargin = self.stringToValue(leftMargin)
            except ValueError:
                raise defs.PmmlValidationError("Improper value in Interval leftMargin specification: \"%s\"" % leftMargin)
            # open on the left excludes the margin itself
            if closure in ("openClosed", "openOpen"):
                outOfRange[NP(data <= leftMargin)] = True
            elif closure in ("closedOpen", "closedClosed"):
                outOfRange[NP(data < leftMargin)] = True
        if rightMargin is not None:
            try:
                rightMargin = self.stringToValue(rightMargin)
            except ValueError:
                raise defs.PmmlValidationError("Improper value in Interval rightMargin specification: \"%s\"" % rightMargin)
            if closure in ("openOpen", "closedOpen"):
                outOfRange[NP(data >= rightMargin)] = True
            elif closure in ("openClosed", "closedClosed"):
                outOfRange[NP(data > rightMargin)] = True
    if not outOfRange.any():
        return data, mask
    if mask is None:
        return data, NP(outOfRange * defs.INVALID)
    # only change what wasn't already marked as MISSING
    NP("logical_and", outOfRange, NP(mask == defs.VALID), outOfRange)
    mask[outOfRange] = defs.INVALID
    return data, mask
def values(self, convertType=False):
    """Extract values from the PMML and represent them in a Pythonic form.

    @type convertType: bool
    @param convertType: If False, return a list of strings; if True, convert the type of the values.
    @rtype: list
    @return: List of values.
    """
    output = []
    if not convertType or self["type"] == "string":
        if self.text is not None:
            for token in re.finditer(self._re_word, self.text):
                g1, g2, _ = token.groups()
                if g2 is not None:
                    # presumably the quoted-string group; unescape embedded quotes
                    output.append(g2.replace(r'\"', '"'))
                elif g1 == r'""':
                    output.append("")
                else:
                    output.append(g1)
    elif self["type"] == "int":
        if self.text is not None:
            try:
                output = [int(x) for x in self.text.split()]
            except ValueError as err:
                raise defs.PmmlValidationError("Array of type int has a badly formatted value: %s" % str(err))
    elif self["type"] == "real":
        if self.text is not None:
            try:
                output = [float(x) for x in self.text.split()]
            except ValueError as err:
                raise defs.PmmlValidationError("Array of type real has a badly formatted value: %s" % str(err))
    return output
def postValidate(self):
    """After XSD validation, check the version of the document.

    The custom ODG version of this class checks for "4.1-odg", to
    avoid confusion among serialized models.
    """
    observed = self.get("version")
    if observed != self.version:
        raise defs.PmmlValidationError("PMML version is \"%s\" when \"%s\" is expected for this class" % (observed, self.version))
def parse(cls, text):
    """Parse a formula, producing an internal syntax tree.

    @type text: string
    @param text: The formula to parse.
    @rtype: Formula.List, Formula.Constant, Formula.FieldRef, or Formula.Apply
    @return: A syntax tree represented by nested class instances.
    """
    if text is None or text.strip() == "":
        raise defs.PmmlValidationError("Formula is empty")
    expressions = cls._parse(text)
    if len(expressions) != 1:
        raise defs.PmmlValidationError("Formula evaluates to %d expressions, rather than 1" % len(expressions))
    tree = expressions[0]
    # a bare list at top level is not a usable expression
    if isinstance(tree, cls.List):
        raise defs.PmmlValidationError("Formula evaluates to a list, rather than Constant, FieldRef, or Apply")
    return tree
def draw(self, state, plotCoordinates, plotDefinitions, performanceTable):
    """Draw the plot element.  This stage consists of creating an SVG
    image of the pre-computed data.

    @type state: ad-hoc Python object
    @param state: State information that persists long enough to use quantities computed in C{prepare} in the C{draw} stage.  This is a work-around of lxml's refusal to let its Python instances maintain C{self} and it is unrelated to DataTableState.
    @type plotCoordinates: PlotCoordinates
    @param plotCoordinates: The coordinate system in which this plot element will be placed.
    @type plotDefinitions: PlotDefinitions
    @param plotDefinitions: The dictionary of key-value pairs that forms the <defs> section of the SVG document.
    @type performanceTable: PerformanceTable
    @param performanceTable: Measures and records performance (time and memory consumption) of the drawing process.
    @rtype: SvgBinding
    @return: An SVG fragment representing the fully drawn plot element.
    """
    svg = SvgBinding.elementMaker
    x1, y1 = float(self["x1"]), float(self["y1"])
    x2, y2 = float(self["x2"]), float(self["y2"])
    inlineSvg = self.getchildren()
    fileName = self.get("fileName")
    if len(inlineSvg) == 1 and fileName is None:
        svgBinding = inlineSvg[0]
    elif len(inlineSvg) == 0 and fileName is not None:
        svgBinding = SvgBinding.loadXml(fileName)
    else:
        raise defs.PmmlValidationError("PlotSvgContent should specify an inline SVG or a fileName but not both or neither")
    sx1, sy1, sx2, sy2 = PlotSvgAnnotation.findSize(svgBinding)
    # map the SVG's own bounding box onto the (x1, y1)-(x2, y2) window
    subCoordinates = PlotCoordinatesWindow(plotCoordinates, sx1, sy1, sx2, sy2, x1, y1, x2 - x1, y2 - y1)
    tx0, ty0 = subCoordinates(0.0, 0.0)
    tx1, ty1 = subCoordinates(1.0, 1.0)
    attribs = {"transform": "translate(%r, %r) scale(%r, %r)" % (tx0, ty0, tx1 - tx0, ty1 - ty0)}
    svgId = self.get("svgId")
    if svgId is not None:
        attribs["id"] = svgId
    if "style" in svgBinding.attrib:
        attribs["style"] = svgBinding.attrib["style"]
    return svg.g(*(copy.deepcopy(svgBinding).getchildren()), **attribs)
def _checkValues(self, data, mask):
    """Apply the <Value> declarations to the data, rebuilding the mask.

    Rows equal to a value declared property="missing" or "invalid" are
    flagged accordingly.  If any values are declared property="valid",
    only those values are accepted ("guilty until proven innocent");
    otherwise every row not explicitly flagged stays valid.

    @type data: 1d Numpy array
    @param data: The data values.
    @type mask: 1d Numpy array of dtype defs.maskType, or None
    @param mask: The current mask, if any.
    @rtype: 2-tuple
    @return: The data and the (possibly rebuilt) mask.
    """
    values = self.values
    if len(values) == 0:
        return data, mask
    if mask is None:
        missingMask = NP("zeros", len(data), dtype=NP.dtype(bool))
        invalidMask = NP("zeros", len(data), dtype=NP.dtype(bool))
    else:
        missingMask = NP(mask == defs.MISSING)
        invalidMask = NP(mask == defs.INVALID)
    validMask = NP("zeros", len(data), dtype=NP.dtype(bool))
    validCount = 0
    for value in values:
        valueString = value.get("value")
        displayValue = value.get("displayValue")
        if displayValue is not None:
            self._displayValue[valueString] = displayValue
        prop = value.get("property", "valid")
        try:
            typedValue = self.stringToValue(valueString)
        except ValueError:
            raise defs.PmmlValidationError("Improper value in Value specification: \"%s\"" % valueString)
        if prop == "valid":
            NP("logical_or", validMask, NP(data == typedValue), validMask)
            validCount += 1
        elif prop == "missing":
            NP("logical_or", missingMask, NP(data == typedValue), missingMask)
        elif prop == "invalid":
            NP("logical_or", invalidMask, NP(data == typedValue), invalidMask)
    if validCount > 0:
        # guilty until proven innocent: only explicitly valid values pass
        NP("logical_and", validMask, NP("logical_not", missingMask), validMask)
        if validMask.all():
            return data, None
        mask = NP(NP("ones", len(data), dtype=defs.maskType) * defs.INVALID)
        mask[missingMask] = defs.MISSING
        mask[validMask] = defs.VALID
    else:
        # innocent until proven guilty: only explicitly flagged rows fail
        NP("logical_and", invalidMask, NP("logical_not", missingMask), invalidMask)
        if not NP("logical_or", invalidMask, missingMask).any():
            return data, None
        mask = NP("zeros", len(data), dtype=defs.maskType)
        mask[missingMask] = defs.MISSING
        mask[invalidMask] = defs.INVALID
    return data, mask
def checkStyleProperties(self):
    """Verify that all properties currently requested in the C{style}
    attribute are in the legal C{styleProperties} list.

    @raise PmmlValidationError: If the list contains an unrecognized style property name, raise an error.  Otherwise, silently pass.
    """
    style = self.get("style")
    if style is None:
        return
    for propertyName in PlotStyle.toDict(style).keys():
        if propertyName not in self.styleProperties:
            raise defs.PmmlValidationError("Unrecognized style property: \"%s\"" % propertyName)
def applyMapMissingTo(fieldType, data, mask, mapMissingTo, overwrite=False):
    """Replace MISSING values with a given substitute.

    This function does not modify the original data (unless
    C{overwrite} is True), but it returns a substitute.

    Example use::

        data, mask = dataColumn.data, dataColumn.mask
        data, mask = FieldCastMethods.applyMapMissingTo(dataColumn.fieldType, data, mask, "-999")
        return DataColumn(dataColumn.fieldType, data, mask)

    It can also be used in conjunction with other FieldCastMethods.

    @type fieldType: FieldType
    @param fieldType: The data fieldType (to interpret C{mapMissingTo}).
    @type data: 1d Numpy array
    @param data: The data.
    @type mask: 1d Numpy array of dtype defs.maskType, or None
    @param mask: The mask.
    @type mapMissingTo: string
    @param mapMissingTo: The replacement value, represented as a string (e.g. directly from a PMML attribute).
    @type overwrite: bool
    @param overwrite: If True, temporarily unlock and overwrite the original mask.
    @rtype: 2-tuple of 1d Numpy arrays
    @return: The new data and mask.
    """
    if mask is None:
        # no mask means nothing is MISSING, so there is nothing to replace
        return data, mask
    if mapMissingTo is not None:
        missingRows = NP(mask == defs.MISSING)
        try:
            replacement = fieldType.stringToValue(mapMissingTo)
        except ValueError as err:
            raise defs.PmmlValidationError("mapMissingTo string \"%s\" cannot be cast as %r: %s" % (mapMissingTo, fieldType, str(err)))
        if overwrite:
            data.setflags(write=True)
            mask.setflags(write=True)
        else:
            data = NP("copy", data)
            mask = NP("copy", mask)
        data[missingRows] = replacement
        mask[missingRows] = defs.VALID
        if not mask.any():
            mask = None
    return data, mask
def establishBinType(fieldType, intervals, values):
    """Determine the type of binning to use for a histogram with the
    given FieldType, Intervals, and Values.

    @type fieldType: FieldType
    @param fieldType: The FieldType of the plot expression.
    @type intervals: list of PmmlBinding
    @param intervals: The <Interval> elements; may be empty.
    @type values: list of PmmlBinding
    @param values: The <Value> elements; may be empty.
    @rtype: string
    @return: One of "nonuniform", "explicit", "unique", "scale".
    """
    numericalOrTemporal = fieldType.isnumeric() or fieldType.istemporal()
    if len(intervals) > 0:
        if not numericalOrTemporal:
            raise defs.PmmlValidationError("Explicit Intervals are intended for numerical data, not %r" % fieldType)
        return "nonuniform"
    if len(values) > 0:
        if not fieldType.isstring():
            raise defs.PmmlValidationError("Explicit Values are intended for string data, not %r" % fieldType)
        return "explicit"
    if fieldType.isstring():
        # strings without explicit Values get one bin per unique value
        return "unique"
    if not numericalOrTemporal:
        raise defs.PmmlValidationError("PlotHistogram requires numerical or string data, not %r" % fieldType)
    return "scale"
def functionSum(self, dataColumn, whereMask, groupSelection, getstate, setstate):
    """Adds up rows in a DataColumn, possibly with an SQL where mask and groupField.

    @type dataColumn: DataColumn
    @param dataColumn: The input data column.
    @type whereMask: 1d Numpy array of bool, or None
    @param whereMask: The result of the SQL where selection.
    @type groupSelection: 1d Numpy array of bool, or None.
    @param groupSelection: Rows corresponding to a particular value of the groupField.
    @type getstate: callable function
    @param getstate: Retrieve starting values from the DataTableState.
    @type setstate: callable function
    @param setstate: Store ending values to the DataTableState.
    @rtype: DataColumn
    @return: A column of added rows.
    """
    fieldType = FakeFieldType("double", "continuous")
    if dataColumn.fieldType.dataType not in ("integer", "float", "double"):
        raise defs.PmmlValidationError("Aggregate function \"sum\" requires a numeric input field: \"integer\", \"float\", \"double\"")
    # weight each row 1.0, then zero out rows removed by the mask or selections
    addends = NP("ones", len(dataColumn), dtype=fieldType.dtype)
    if dataColumn.mask is not None:
        NP("logical_and", addends, NP(dataColumn.mask == defs.VALID), addends)
    if whereMask is not None:
        NP("logical_and", addends, whereMask, addends)
    if groupSelection is not None:
        NP("logical_and", addends, groupSelection, addends)
    NP("multiply", addends, dataColumn.data, addends)
    if getstate is not None and len(dataColumn) > 0:
        startingState = getstate()
        if startingState is not None:
            # fold the running total from previous tables into the first row
            addends[0] += startingState
    data = NP("cumsum", addends)
    if setstate is not None and len(dataColumn) > 0:
        setstate(data[-1])
    return DataColumn(fieldType, data, None)
def checkRoles(self, expected):
    """Helper method to verify that all expected roles are present.

    (Some plot types use PlotFormula/PlotExpression/PlotNumericExpression
    elements with predefined "roles" to specify the contents of the axes.)

    @type expected: list of strings
    @param expected: The names of the roles that are required.
    @raise PmmlValidationError: If a role is unrecognized, this method raises an error; otherwise, it silently passes.
    """
    rolesQuery = "pmml:PlotFormula/@role | pmml:PlotExpression/@role | pmml:PlotNumericExpression/@role"
    for role in self.xpath(rolesQuery):
        if role not in expected:
            raise defs.PmmlValidationError("Unrecognized role: \"%s\" (expected one of \"%s\")" % (role, "\" \"".join(expected)))
def toDict(value):
    """Convert a CSS style string into a dictionary.

    @type value: string
    @param value: A string with the form "name1: value1; name2: value2".
    @rtype: dict
    @return: A dictionary that maps style property names to their values.
    """
    if isinstance(value, PlotStyle):
        value = value._parent.get("style")
    if not isinstance(value, basestring):
        # already a dict (or None); pass it through unchanged
        return value
    try:
        clauses = (clause for clause in value.split(";") if clause.strip() != "")
        return dict([part.strip() for part in clause.split(":")] for clause in clauses)
    except ValueError:
        raise defs.PmmlValidationError("Improperly formatted style string: \"%s\"" % value)
def evaluateOne(self, convertType=True):
    """Evaluate the constant only once, not for every row of a DataColumn.

    @type convertType: bool
    @param convertType: If True, convert the type from a string into a Pythonic value.
    @rtype: string or object
    @return: Only one copy of the constant.
    """
    # NOTE(review): convertType is accepted but never consulted below; the
    # text is always passed through fieldType.stringToValue -- confirm whether
    # callers passing convertType=False rely on a string-typed fieldType here.
    text = self.text.strip()
    try:
        return self.fieldType.stringToValue(text)
    except ValueError as err:
        raise defs.PmmlValidationError("Constant \"%s\" cannot be cast as %r: %s" % (text, self.fieldType, str(err)))
def _checkFieldTypeX(self, xfieldType):
    """Record the FieldType of the x axis, or cross-check it against the
    FieldType recorded by a previously-prepared overlaid plot element.

    @type xfieldType: FieldType
    @param xfieldType: The FieldType proposed for the x axis.
    @raise PmmlValidationError: If the proposed type conflicts with the recorded one.
    """
    recorded = self.xfieldType
    if recorded is None:
        self.xfieldType = xfieldType
        return
    # the same elif ladder as before, but with a single raise site
    if recorded.isstring() and not xfieldType.isstring():
        conflict = True
    elif recorded.isstring() and recorded.optype == "ordinal" and recorded != xfieldType:
        # ordinal strings must agree exactly, including value ordering
        conflict = True
    elif recorded.isboolean() and not xfieldType.isboolean():
        conflict = True
    elif recorded.isnumeric() and not xfieldType.isnumeric():
        conflict = True
    elif recorded.istime() and not xfieldType.istime():
        conflict = True
    elif (recorded.isdate() or recorded.isdatetime()) and not (xfieldType.isdate() or xfieldType.isdatetime()):
        conflict = True
    else:
        conflict = False
    if conflict:
        raise defs.PmmlValidationError("Overlaid x plot axis has conflicting types: %r and %r" % (recorded, xfieldType))
def where(self, dataTable, functionTable, performanceTable):
    """Approximate implementation of SQL where using the Formula class.

    It has a C{between} operator and various other SQL-like methods,
    but it is not syntactically identical to SQL.  See the Formula
    class for more.

    @type dataTable: DataTable
    @param dataTable: The input DataTable, containing any fields that might be used to evaluate this expression.
    @type functionTable: FunctionTable
    @param functionTable: The FunctionTable, containing any functions that might be called in this expression.
    @type performanceTable: PerformanceTable
    @param performanceTable: A PerformanceTable for measuring the efficiency of the calculation.
    @rtype: 1d Numpy array of bool
    @return: The result as a Numpy selector.
    """
    formula = self.get("sqlWhere")
    if formula is None:
        return None
    performanceTable.begin("Aggregate sqlWhere")
    result = Formula().evaluate(dataTable, functionTable, performanceTable, formula)
    if result.fieldType.dataType != "boolean":
        raise defs.PmmlValidationError("Aggregate sqlWhere must evaluate to a boolean expression, not \"%s\"" % formula)
    result._unlock()
    if result.mask is not None:
        # rows that are not VALID do not pass the where filter
        NP("logical_and", result.data, NP(result.mask == defs.VALID), result.data)
    performanceTable.end("Aggregate sqlWhere")
    return result.data
def functionMax(self, dataColumn, whereMask, groupSelection, getstate, setstate):
    """Finds the maximum of rows in a DataColumn, possibly with an SQL where mask and groupField.

    @type dataColumn: DataColumn
    @param dataColumn: The input data column.
    @type whereMask: 1d Numpy array of bool, or None
    @param whereMask: The result of the SQL where selection.
    @type groupSelection: 1d Numpy array of bool, or None.
    @param groupSelection: Rows corresponding to a particular value of the groupField.
    @type getstate: callable function
    @param getstate: Retrieve starting values from the DataTableState.
    @type setstate: callable function
    @param setstate: Store ending values to the DataTableState.
    @rtype: DataColumn
    @return: A column of maximized rows.
    """
    fieldType = dataColumn.fieldType
    if fieldType.optype not in ("continuous", "ordinal"):
        # BUG FIX: the message previously said "min"; this is the "max" aggregate
        raise defs.PmmlValidationError("Aggregate function \"max\" requires a continuous or ordinal input field")
    if dataColumn.mask is None:
        selection = NP("ones", len(dataColumn), dtype=NP.dtype(bool))
    else:
        selection = NP(dataColumn.mask == defs.VALID)
    if whereMask is not None:
        NP("logical_and", selection, whereMask, selection)
    if groupSelection is not None:
        NP("logical_and", selection, groupSelection, selection)
    maximum = None
    if getstate is not None:
        startingState = getstate()
        if startingState is not None:
            maximum = startingState
    data = NP("empty", len(dataColumn), dtype=fieldType.dtype)
    mask = NP("zeros", len(dataColumn), dtype=defs.maskType)
    # running maximum: each row reports the largest selected value seen so far
    for i, x in enumerate(dataColumn.data):
        if selection[i]:
            if maximum is None or x > maximum:
                maximum = x
        if maximum is None:
            # no value seen yet, so the running maximum is undefined
            mask[i] = defs.INVALID
        else:
            data[i] = maximum
    if not mask.any():
        mask = None
    if setstate is not None:
        setstate(maximum)
    return DataColumn(fieldType, data, mask)
def _setup(self):
    """Configure this FieldType's conversion callables and Numpy dtype from its dataType/optype.

    Dispatches over C{self.dataType} (and, for strings, C{self.optype}) to bind
    C{toDataColumn}, C{fromDataColumn}, C{dtype}, C{stringToValue},
    C{valueToString}, and C{valueToPython} to the appropriate private
    implementations, then computes a hash for the field type.

    @raise PmmlValidationError: If a non-continuous field has Intervals, or if
    the optype/dataType is unrecognized.
    """

    # Intervals only make sense for continuous fields
    if self.optype != "continuous" and len(self.intervals) > 0:
        raise defs.PmmlValidationError("Non-continuous fields cannot have Intervals")

    self._displayValue = {}

    if self.dataType == "object":
        # for scoring results that don't fit the PMML pattern
        self.toDataColumn = self._toDataColumn_object
        self.fromDataColumn = self._fromDataColumn_object
        self.dtype = NP.dtype(object)
        self.stringToValue = self._stringToValue_object
        self.valueToString = self._valueToString_object
        self.valueToPython = self._valueToPython

    elif self.dataType == "string":
        if self.optype == "categorical":
            # strings are interned as integer codes; the two dicts map both ways
            self._stringToValue = {}   # TODO: merge categorical and ordinal <Value> handling
            self._valueToString = {}   # into _checkValues(data, mask)
            self._newValuesAllowed = True
            for value in self.values:
                v = value.get("value")
                displayValue = value.get("displayValue")
                if displayValue is not None:
                    self._displayValue[v] = displayValue
                if value.get("property", "valid") == "valid":
                    self._addCategorical(v)
            # if any <Value> elements declared categories, the set is closed
            if len(self._stringToValue) > 0:
                self._newValuesAllowed = False

            self.toDataColumn = self._toDataColumn_internal
            self.fromDataColumn = self._fromDataColumn
            self.dtype = NP.int64
            self.stringToValue = self._stringToValue_categorical
            self.valueToString = self._valueToString_categorical
            self.valueToPython = self._valueToString_categorical

        elif self.optype == "ordinal":
            self._stringToValue = {}   # TODO: see above
            self._valueToString = {}
            self._newValuesAllowed = True
            for value in self.values:
                v = value.get("value")
                displayValue = value.get("displayValue")
                if displayValue is not None:
                    self._displayValue[v] = displayValue
                if value.get("property", "valid") == "valid":
                    self._addOrdinal(v)
            # ordinal value sets are always closed after declaration
            self._newValuesAllowed = False

            self.toDataColumn = self._toDataColumn_internal
            self.fromDataColumn = self._fromDataColumn
            self.dtype = NP.dtype(int)
            self.stringToValue = self._stringToValue_ordinal
            self.valueToString = self._valueToString_ordinal
            self.valueToPython = self._valueToString_ordinal

        elif self.optype == "continuous":
            # continuous strings are kept as Python objects, not coded
            self.toDataColumn = self._toDataColumn_string
            self.fromDataColumn = self._fromDataColumn_object
            self.dtype = NP.dtype(object)
            self.stringToValue = self._stringToValue_string
            self.valueToString = self._valueToString_string
            self.valueToPython = self._valueToString_string

        else:
            raise defs.PmmlValidationError("Unrecognized optype: %s" % self.optype)

    elif self.dataType == "integer":
        self.toDataColumn = self._toDataColumn_number
        self.fromDataColumn = self._fromDataColumn_number
        self.dtype = NP.dtype(int)
        self.stringToValue = self._stringToValue_integer
        self.valueToString = self._valueToString_integer
        self.valueToPython = self._valueToPython

    elif self.dataType == "float":
        self.toDataColumn = self._toDataColumn_number
        self.fromDataColumn = self._fromDataColumn_number
        self.dtype = NP.float32
        self.stringToValue = self._stringToValue_float
        self.valueToString = self._valueToString_float
        self.valueToPython = self._valueToPython

    elif self.dataType == "double":
        self.toDataColumn = self._toDataColumn_number
        self.fromDataColumn = self._fromDataColumn_number
        self.dtype = NP.dtype(float)
        self.stringToValue = self._stringToValue_double
        self.valueToString = self._valueToString_double
        self.valueToPython = self._valueToPython

    elif self.dataType == "boolean":
        self.toDataColumn = self._toDataColumn_number
        self.fromDataColumn = self._fromDataColumn_number
        self.dtype = NP.dtype(bool)
        self.stringToValue = self._stringToValue_boolean
        self.valueToString = self._valueToString_boolean
        self.valueToPython = self._valueToPython

    elif self.dataType == "date":
        self.toDataColumn = self._toDataColumn_dateTime
        self.fromDataColumn = self._fromDataColumn
        self.dtype = NP.int64
        self.stringToValue = self._stringToValue_date
        self.valueToString = self._valueToString_date
        self.valueToPython = self._valueToPython_date

    elif self.dataType == "time":
        self.toDataColumn = self._toDataColumn_dateTime
        self.fromDataColumn = self._fromDataColumn
        self.dtype = NP.int64
        self.stringToValue = self._stringToValue_time
        self.valueToString = self._valueToString_time
        self.valueToPython = self._valueToPython_time

    elif self.dataType == "dateTime":
        self.toDataColumn = self._toDataColumn_dateTime
        self.fromDataColumn = self._fromDataColumn
        self.dtype = NP.int64
        self.stringToValue = self._stringToValue_dateTime
        self.valueToString = self._valueToString_dateTime
        self.valueToPython = self._valueToPython_dateTime

    elif self.dataType == "dateDaysSince[0]":
        # _offset is the number of seconds between 1/1/1 B.C. and 1/1/1970, using the astronomical convention
        # that 1 B.C. is "year zero" (which does not exist, even in the proleptic Gregorian calendar)
        # and that this fictitious year would have been a leap year (366 full days)
        # http://en.wikipedia.org/wiki/Year_zero#Astronomers
        self._offset = -62167219200 * self._dateTimeResolution
        self._factor = 86400 * self._dateTimeResolution   # number of microseconds in a day
        self.toDataColumn = self._toDataColumn_dateTimeNumber
        self.fromDataColumn = self._fromDataColumn_dateTimeNumber
        self.dtype = NP.int64
        self.stringToValue = self._stringToValue_dateTimeNumber
        self.valueToString = self._valueToString_dateTimeNumber
        self.valueToPython = self._valueToPython_dateTimeNumber

    elif self.dataType == "dateDaysSince[1960]":
        self._offset = -315619200 * self._dateTimeResolution   # number of seconds between 1/1/1960 and 1/1/1970, accounting for leap years/leap seconds
        self._factor = 86400 * self._dateTimeResolution   # number of microseconds in a day
        self.toDataColumn = self._toDataColumn_dateTimeNumber
        self.fromDataColumn = self._fromDataColumn_dateTimeNumber
        self.dtype = NP.int64
        self.stringToValue = self._stringToValue_dateTimeNumber
        self.valueToString = self._valueToString_dateTimeNumber
        self.valueToPython = self._valueToPython_dateTimeNumber

    elif self.dataType == "dateDaysSince[1970]":
        self._offset = 0
        self._factor = 86400 * self._dateTimeResolution   # number of microseconds in a day
        self.toDataColumn = self._toDataColumn_dateTimeNumber
        self.fromDataColumn = self._fromDataColumn_dateTimeNumber
        self.dtype = NP.int64
        self.stringToValue = self._stringToValue_dateTimeNumber
        self.valueToString = self._valueToString_dateTimeNumber
        self.valueToPython = self._valueToPython_dateTimeNumber

    elif self.dataType == "dateDaysSince[1980]":
        self._offset = 315532800 * self._dateTimeResolution   # number of seconds between 1/1/1980 and 1/1/1970, accounting for leap years/leap seconds
        self._factor = 86400 * self._dateTimeResolution   # number of microseconds in a day
        self.toDataColumn = self._toDataColumn_dateTimeNumber
        self.fromDataColumn = self._fromDataColumn_dateTimeNumber
        self.dtype = NP.int64
        self.stringToValue = self._stringToValue_dateTimeNumber
        self.valueToString = self._valueToString_dateTimeNumber
        self.valueToPython = self._valueToPython_dateTimeNumber

    elif self.dataType == "timeSeconds":
        self._offset = 0
        self._factor = self._dateTimeResolution   # number of microseconds in a second
        self.toDataColumn = self._toDataColumn_dateTimeNumber
        self.fromDataColumn = self._fromDataColumn_timeSeconds   # reports modulo 1 day
        self.dtype = NP.int64
        self.stringToValue = self._stringToValue_dateTimeNumber
        self.valueToString = self._valueToString_timeSeconds   # reports modulo 1 day
        self.valueToPython = self._valueToPython_timeSeconds   # reports modulo 1 day

    elif self.dataType == "dateTimeSecondsSince[0]":
        self._offset = -62167219200 * self._dateTimeResolution   # number of seconds between 1/1/1 B.C. and 1/1/1970, accounting for leap years/leap seconds
        self._factor = self._dateTimeResolution   # number of microseconds in a second
        self.toDataColumn = self._toDataColumn_dateTimeNumber
        self.fromDataColumn = self._fromDataColumn_dateTimeNumber
        self.dtype = NP.int64
        self.stringToValue = self._stringToValue_dateTimeNumber
        self.valueToString = self._valueToString_dateTimeNumber
        self.valueToPython = self._valueToPython_dateTimeNumber

    elif self.dataType == "dateTimeSecondsSince[1960]":
        self._offset = -315619200 * self._dateTimeResolution   # number of seconds between 1/1/1960 and 1/1/1970, accounting for leap years/leap seconds
        self._factor = self._dateTimeResolution   # number of microseconds in a second
        self.toDataColumn = self._toDataColumn_dateTimeNumber
        self.fromDataColumn = self._fromDataColumn_dateTimeNumber
        self.dtype = NP.int64
        self.stringToValue = self._stringToValue_dateTimeNumber
        self.valueToString = self._valueToString_dateTimeNumber
        self.valueToPython = self._valueToPython_dateTimeNumber

    elif self.dataType == "dateTimeSecondsSince[1970]":
        self._offset = 0
        self._factor = self._dateTimeResolution   # number of microseconds in a second
        self.toDataColumn = self._toDataColumn_dateTimeNumber
        self.fromDataColumn = self._fromDataColumn_dateTimeNumber
        self.dtype = NP.int64
        self.stringToValue = self._stringToValue_dateTimeNumber
        self.valueToString = self._valueToString_dateTimeNumber
        self.valueToPython = self._valueToPython_dateTimeNumber

    elif self.dataType == "dateTimeSecondsSince[1980]":
        self._offset = 315532800 * self._dateTimeResolution   # number of seconds between 1/1/1980 and 1/1/1970, accounting for leap years/leap seconds
        self._factor = self._dateTimeResolution   # number of microseconds in a second
        self.toDataColumn = self._toDataColumn_dateTimeNumber
        self.fromDataColumn = self._fromDataColumn_dateTimeNumber
        self.dtype = NP.int64
        self.stringToValue = self._stringToValue_dateTimeNumber
        self.valueToString = self._valueToString_dateTimeNumber
        self.valueToPython = self._valueToPython_dateTimeNumber

    else:
        raise defs.PmmlValidationError("Unrecognized dataType: %s" % self.dataType)

    # identity of a FieldType is defined by these five properties
    self._hash = hash((self.dataType, self.optype, tuple(self.values), tuple(self.intervals), self.isCyclic))
def prepare(self, state, dataTable, functionTable, performanceTable, plotRange):
    """Prepare a plot element for drawing.

    This stage consists of calculating all quantities and
    determing the bounds of the data.  These bounds may be unioned
    with bounds from other plot elements that overlay this plot
    element, so the drawing (which requires a finalized coordinate
    system) cannot begin yet.

    This method modifies C{plotRange}.

    @type state: ad-hoc Python object
    @param state: State information that persists long enough to use quantities computed in C{prepare} in the C{draw} stage.  This is a work-around of lxml's refusal to let its Python instances maintain C{self} and it is unrelated to DataTableState.
    @type dataTable: DataTable
    @param dataTable: Contains the data to plot.
    @type functionTable: FunctionTable
    @param functionTable: Defines functions that may be used to transform data for plotting.
    @type performanceTable: PerformanceTable
    @param performanceTable: Measures and records performance (time and memory consumption) of the drawing process.
    @type plotRange: PlotRange
    @param plotRange: The bounding box of plot coordinates that this function will expand.
    """

    self.checkRoles(["y(x)", "dy/dx", "x(t)", "y(t)", "dx/dt", "dy/dt", "x", "y", "dx", "dy"])

    performanceTable.begin("PlotCurve prepare")
    self._saveContext(dataTable)

    # collect all possible child elements by role; which ones are present
    # determines whether the curve is formula-defined or data-defined
    yofx = self.xpath("pmml:PlotFormula[@role='y(x)']")
    dydx = self.xpath("pmml:PlotFormula[@role='dy/dx']")
    xoft = self.xpath("pmml:PlotFormula[@role='x(t)']")
    yoft = self.xpath("pmml:PlotFormula[@role='y(t)']")
    dxdt = self.xpath("pmml:PlotFormula[@role='dx/dt']")
    dydt = self.xpath("pmml:PlotFormula[@role='dy/dt']")
    nx = self.xpath("pmml:PlotNumericExpression[@role='x']")
    ny = self.xpath("pmml:PlotNumericExpression[@role='y']")
    ndx = self.xpath("pmml:PlotNumericExpression[@role='dx']")
    ndy = self.xpath("pmml:PlotNumericExpression[@role='dy']")
    cutExpression = self.xpath("pmml:PlotSelection")

    if len(yofx) + len(dydx) + len(xoft) + len(yoft) + len(dxdt) + len(dydt) > 0:
        # formula-defined curve: only four combinations of roles are legal
        if len(yofx) == 1 and len(dydx) == 0 and len(xoft) == 0 and len(yoft) == 0 and len(dxdt) == 0 and len(dydt) == 0:
            expression = (yofx[0].text,)
            derivative = (None,)
        elif len(yofx) == 1 and len(dydx) == 1 and len(xoft) == 0 and len(yoft) == 0 and len(dxdt) == 0 and len(dydt) == 0:
            expression = (yofx[0].text,)
            derivative = (dydx[0].text,)
        elif len(yofx) == 0 and len(dydx) == 0 and len(xoft) == 1 and len(yoft) == 1 and len(dxdt) == 0 and len(dydt) == 0:
            expression = xoft[0].text, yoft[0].text
            derivative = None, None
        elif len(yofx) == 0 and len(dydx) == 0 and len(xoft) == 1 and len(yoft) == 1 and len(dxdt) == 1 and len(dydt) == 1:
            expression = xoft[0].text, yoft[0].text
            derivative = dxdt[0].text, dydt[0].text
        else:
            raise defs.PmmlValidationError("The only allowed combinations of PlotFormulae are: \"y(x)\", \"y(x) dy/dx\", \"x(t) y(t)\", and \"x(t) y(t) dx/dt dy/dt\"")

        low = self.get("low", convertType=True)
        high = self.get("high", convertType=True)
        if low is None or high is None:
            raise defs.PmmlValidationError("The \"low\" and \"high\" attributes are required for PlotCurves defined by formulae")

        samples = self.generateSamples(low, high)

        loop = self.get("loop", defaultFromXsd=True, convertType=True)
        state.x, state.y, state.dx, state.dy, xfieldType, yfieldType = self.expressionsToPoints(expression, derivative, samples, loop, functionTable, performanceTable)

    else:
        # data-defined curve: evaluate numeric expressions against the DataTable;
        # timers are paused so child evaluations account for their own time
        performanceTable.pause("PlotCurve prepare")
        if len(ndx) == 1:
            dxdataColumn = ndx[0].evaluate(dataTable, functionTable, performanceTable)
        else:
            dxdataColumn = None
        if len(ndy) == 1:
            dydataColumn = ndy[0].evaluate(dataTable, functionTable, performanceTable)
        else:
            dydataColumn = None
        performanceTable.unpause("PlotCurve prepare")

        if len(nx) == 0 and len(ny) == 1:
            # y-only: x becomes an implicit 0,1,2,... index
            performanceTable.pause("PlotCurve prepare")
            ydataColumn = ny[0].evaluate(dataTable, functionTable, performanceTable)
            performanceTable.unpause("PlotCurve prepare")

            if len(cutExpression) == 1:
                performanceTable.pause("PlotCurve prepare")
                selection = cutExpression[0].select(dataTable, functionTable, performanceTable)
                performanceTable.unpause("PlotCurve prepare")
            else:
                selection = NP("ones", len(ydataColumn.data), NP.dtype(bool))

            # drop rows whose inputs are masked invalid
            if ydataColumn.mask is not None:
                selection = NP("logical_and", selection, NP(ydataColumn.mask == defs.VALID), selection)
            if dxdataColumn is not None and dxdataColumn.mask is not None:
                selection = NP("logical_and", selection, NP(dxdataColumn.mask == defs.VALID), selection)
            if dydataColumn is not None and dydataColumn.mask is not None:
                selection = NP("logical_and", selection, NP(dydataColumn.mask == defs.VALID), selection)

            yarray = ydataColumn.data[selection]

            # build x = 0, 1, 2, ... with a cumulative sum of ones
            xarray = NP("ones", len(yarray), dtype=NP.dtype(float))
            xarray[0] = 0.0
            xarray = NP("cumsum", xarray)

            dxarray, dyarray = None, None
            if dxdataColumn is not None:
                dxarray = dxdataColumn.data[selection]
            if dydataColumn is not None:
                dyarray = dydataColumn.data[selection]

            xfieldType = self.xfieldType
            yfieldType = ydataColumn.fieldType

        elif len(nx) == 1 and len(ny) == 1:
            # explicit x and y expressions
            performanceTable.pause("PlotCurve prepare")
            xdataColumn = nx[0].evaluate(dataTable, functionTable, performanceTable)
            ydataColumn = ny[0].evaluate(dataTable, functionTable, performanceTable)
            performanceTable.unpause("PlotCurve prepare")

            if len(cutExpression) == 1:
                performanceTable.pause("PlotCurve prepare")
                selection = cutExpression[0].select(dataTable, functionTable, performanceTable)
                performanceTable.unpause("PlotCurve prepare")
            else:
                selection = NP("ones", len(ydataColumn.data), NP.dtype(bool))

            if xdataColumn.mask is not None:
                selection = NP("logical_and", selection, NP(xdataColumn.mask == defs.VALID), selection)
            if ydataColumn.mask is not None:
                selection = NP("logical_and", selection, NP(ydataColumn.mask == defs.VALID), selection)
            if dxdataColumn is not None and dxdataColumn.mask is not None:
                selection = NP("logical_and", selection, NP(dxdataColumn.mask == defs.VALID), selection)
            if dydataColumn is not None and dydataColumn.mask is not None:
                selection = NP("logical_and", selection, NP(dydataColumn.mask == defs.VALID), selection)

            xarray = xdataColumn.data[selection]
            yarray = ydataColumn.data[selection]
            dxarray, dyarray = None, None
            if dxdataColumn is not None:
                dxarray = dxdataColumn.data[selection]
            if dydataColumn is not None:
                dyarray = dydataColumn.data[selection]

            xfieldType = xdataColumn.fieldType
            yfieldType = ydataColumn.fieldType

        else:
            raise defs.PmmlValidationError("The only allowed combinations of PlotNumericExpressions are: \"y(x)\" and \"x(t) y(t)\"")

        # stateId lets points accumulate across multiple calls (streaming);
        # previously stored points are prepended via concatenation
        persistentState = {}
        stateId = self.get("stateId")
        if stateId is not None:
            if stateId in dataTable.state:
                persistentState = dataTable.state[stateId]
                xarray = NP("concatenate", [xarray, persistentState["x"]])
                yarray = NP("concatenate", [yarray, persistentState["y"]])
                if dxarray is not None:
                    dxarray = NP("concatenate", [dxarray, persistentState["dx"]])
                if dyarray is not None:
                    dyarray = NP("concatenate", [dyarray, persistentState["dy"]])
            else:
                dataTable.state[stateId] = persistentState

        persistentState["x"] = xarray
        persistentState["y"] = yarray
        if dxarray is not None:
            persistentState["dx"] = dxarray
        if dyarray is not None:
            persistentState["dy"] = dyarray

        smooth = self.get("smooth", defaultFromXsd=True, convertType=True)
        if not smooth:
            # no smoothing: if only dy was given, approximate dx with a
            # centered difference of x and scale dy by it
            if dyarray is not None and dxarray is None:
                dxarray = NP((NP("roll", xarray, -1) - NP("roll", xarray, 1)) / 2.0)
                dyarray = dyarray * dxarray

            loop = self.get("loop", defaultFromXsd=True, convertType=True)
            if dxarray is not None and not loop:
                # open curve: zero out the wrap-around differences at the ends
                dxarray[0] = 0.0
                dxarray[-1] = 0.0
            if dyarray is not None and not loop:
                dyarray[0] = 0.0
                dyarray[-1] = 0.0

            state.x = xarray
            state.y = yarray
            state.dx = dxarray
            state.dy = dyarray

        else:
            smoothingScale = self.get("smoothingScale", defaultFromXsd=True, convertType=True)
            loop = self.get("loop", defaultFromXsd=True, convertType=True)

            # resample over the observed x range and fit a smooth curve
            samples = self.generateSamples(xarray.min(), xarray.max())
            state.x, state.y, state.dx, state.dy = self.pointsToSmoothCurve(xarray, yarray, samples, smoothingScale, loop)

    if plotRange is not None:
        plotRange.expand(state.x, state.y, xfieldType, yfieldType)

    performanceTable.end("PlotCurve prepare")
def expressionsToPoints(cls, expression, derivative, samples, loop, functionTable, performanceTable):
    """Evaluate a set of given string-based formulae to generate
    numeric points.

    This is used to plot mathematical curves.

    @type expression: 1- or 2-tuple of strings
    @param expression: If a 1-tuple, the string is passed to Formula and interpreted as y(x); if a 2-tuple, the strings are passed to Formula and interpreted as x(t), y(t).
    @type derivative: 1- or 2-tuple of strings (same length as C{expression})
    @param derivative: Strings are passed to Formua and interpreted as dy/dx (if a 1-tuple) or dx/dt, dy/dt (if a 2-tuple).
    @type samples: 1d Numpy array
    @param samples: Values of x or t at which to evaluate the expression or expressions.
    @type loop: bool
    @param loop: If False, disconnect the end of the set of points from the beginning.
    @type functionTable: FunctionTable
    @param functionTable: Functions that may be used to perform the calculation.
    @type performanceTable: PerformanceTable
    @param performanceTable: Measures and records performance (time and memory consumption) of the process.
    @rtype: 6-tuple
    @return: C{xlist}, C{ylist}, C{dxlist}, C{dylist} (1d Numpy arrays), xfieldType, yfieldType (FieldTypes).
    """

    if len(expression) == 1:
        # y(x) form: the samples themselves are the x values
        sampleTable = DataTable({"x": "double"}, {"x": samples})

        parsed = Formula.parse(expression[0])
        ydataColumn = parsed.evaluate(sampleTable, functionTable, performanceTable)
        if not ydataColumn.fieldType.isnumeric() and not ydataColumn.fieldType.istemporal():
            raise defs.PmmlValidationError("PlotFormula y(x) must return a numeric expression, not %r" % ydataColumn.fieldType)

        xfieldType = cls.xfieldType
        yfieldType = ydataColumn.fieldType

        # selection stays None when nothing is masked; otherwise it keeps valid rows
        selection = None
        if ydataColumn.mask is not None:
            selection = NP(ydataColumn.mask == defs.VALID)

        if derivative[0] is None:
            if selection is None:
                xlist = samples
                ylist = ydataColumn.data
            else:
                xlist = samples[selection]
                ylist = ydataColumn.data[selection]

            # no analytic derivative: centered finite differences
            dxlist = NP((NP("roll", xlist, -1) - NP("roll", xlist, 1)) / 2.0)
            dylist = NP((NP("roll", ylist, -1) - NP("roll", ylist, 1)) / 2.0)
            if not loop:
                # open curve: zero the wrap-around differences at the ends
                dxlist[0] = 0.0
                dxlist[-1] = 0.0
                dylist[0] = 0.0
                dylist[-1] = 0.0

        else:
            # analytic derivative dy/dx supplied
            parsed = Formula.parse(derivative[0])
            dydataColumn = parsed.evaluate(sampleTable, functionTable, performanceTable)
            if not dydataColumn.fieldType.isnumeric() and not dydataColumn.fieldType.istemporal():
                raise defs.PmmlValidationError("PlotFormula dy/dx must return a numeric expression, not %r" % dydataColumn.fieldType)

            if dydataColumn.mask is not None:
                if selection is None:
                    selection = NP(dydataColumn.mask == defs.VALID)
                else:
                    NP("logical_and", selection, NP(dydataColumn.mask == defs.VALID), selection)

            if selection is None:
                xlist = samples
                ylist = ydataColumn.data
                dxlist = NP((NP("roll", xlist, -1) - NP("roll", xlist, 1)) / 2.0)
                # NOTE(review): the masked branch below multiplies by dxlist but
                # this branch does not — looks asymmetric; confirm intended
                dylist = dydataColumn.data
            else:
                xlist = samples[selection]
                ylist = ydataColumn.data[selection]
                dxlist = NP((NP("roll", xlist, -1) - NP("roll", xlist, 1)) / 2.0)
                dylist = NP(dydataColumn.data[selection] * dxlist)
            if not loop:
                dxlist[0] = 0.0
                dxlist[-1] = 0.0
                dylist[0] = 0.0
                dylist[-1] = 0.0

    elif len(expression) == 2:
        # parametric x(t), y(t) form: the samples are the t values
        sampleTable = DataTable({"t": "double"}, {"t": samples})

        parsed = Formula.parse(expression[0])
        xdataColumn = parsed.evaluate(sampleTable, functionTable, performanceTable)
        if not xdataColumn.fieldType.isnumeric() and not xdataColumn.fieldType.istemporal():
            raise defs.PmmlValidationError("PlotFormula x(t) must return a numeric expression, not %r" % xdataColumn.fieldType)

        parsed = Formula.parse(expression[1])
        ydataColumn = parsed.evaluate(sampleTable, functionTable, performanceTable)
        if not ydataColumn.fieldType.isnumeric() and not ydataColumn.fieldType.istemporal():
            raise defs.PmmlValidationError("PlotFormula y(t) must return a numeric expression, not %r" % ydataColumn.fieldType)

        xfieldType = xdataColumn.fieldType
        yfieldType = ydataColumn.fieldType

        # intersect validity masks of x(t) and y(t)
        selection = None
        if xdataColumn.mask is not None:
            selection = NP(xdataColumn.mask == defs.VALID)
        if ydataColumn.mask is not None:
            if selection is None:
                selection = NP(ydataColumn.mask == defs.VALID)
            else:
                NP("logical_and", selection, NP(ydataColumn.mask == defs.VALID), selection)

        if derivative[0] is None:
            if selection is None:
                xlist = xdataColumn.data
                ylist = ydataColumn.data
            else:
                xlist = xdataColumn.data[selection]
                ylist = ydataColumn.data[selection]

            # no analytic derivatives: centered finite differences
            dxlist = NP((NP("roll", xlist, -1) - NP("roll", xlist, 1)) / 2.0)
            dylist = NP((NP("roll", ylist, -1) - NP("roll", ylist, 1)) / 2.0)
            if not loop:
                dxlist[0] = 0.0
                dxlist[-1] = 0.0
                dylist[0] = 0.0
                dylist[-1] = 0.0

        else:
            # analytic derivatives dx/dt, dy/dt supplied
            parsed = Formula.parse(derivative[0])
            dxdataColumn = parsed.evaluate(sampleTable, functionTable, performanceTable)
            if not dxdataColumn.fieldType.isnumeric() and not dxdataColumn.fieldType.istemporal():
                raise defs.PmmlValidationError("PlotFormula dx/dt must return a numeric expression, not %r" % dxdataColumn.fieldType)

            parsed = Formula.parse(derivative[1])
            dydataColumn = parsed.evaluate(sampleTable, functionTable, performanceTable)
            if not dydataColumn.fieldType.isnumeric() and not dydataColumn.fieldType.istemporal():
                raise defs.PmmlValidationError("PlotFormula dy/dt must return a numeric expression, not %r" % dydataColumn.fieldType)

            if dxdataColumn.mask is not None:
                if selection is None:
                    selection = NP(dxdataColumn.mask == defs.VALID)
                else:
                    NP("logical_and", selection, NP(dxdataColumn.mask == defs.VALID), selection)
            if dydataColumn.mask is not None:
                if selection is None:
                    selection = NP(dydataColumn.mask == defs.VALID)
                else:
                    NP("logical_and", selection, NP(dydataColumn.mask == defs.VALID), selection)

            if selection is None:
                # dt via centered differences of the t samples, then scale
                # the analytic derivatives by dt
                dt = NP((NP("roll", samples, -1) - NP("roll", samples, 1)) / 2.0)
                xlist = xdataColumn.data
                ylist = ydataColumn.data
                dxlist = NP(dxdataColumn.data * dt)
                dylist = NP(dydataColumn.data * dt)
            else:
                dt = NP((NP("roll", samples[selection], -1) - NP("roll", samples[selection], 1)) / 2.0)
                xlist = xdataColumn.data[selection]
                ylist = ydataColumn.data[selection]
                dxlist = NP(dxdataColumn.data[selection] * dt)
                dylist = NP(dydataColumn.data[selection] * dt)
            if not loop:
                dxlist[0] = 0.0
                dxlist[-1] = 0.0
                dylist[0] = 0.0
                dylist[-1] = 0.0

    return xlist, ylist, dxlist, dylist, xfieldType, yfieldType