예제 #1
0
파일: cart.py 프로젝트: bwengals/hadrian
    def pfaValueType(self, dataType):
        """Create a PFA type schema representing the comparison values.

        Every field of ``dataType`` that was also used to train the tree
        contributes one candidate schema:

        * a numeric training field contributes its declared Avro numeric type;
        * a string training field contributes ``"string"``, or an array of
          strings when ``self.maxSubsetSize`` is not 1.

        The candidates are parsed with ``ForwardDeclarationParser`` and
        combined by ``LabelData.broadestType``.

        :type dataType: Pythonized JSON
        :param dataType: Avro record schema of the input data
        :return: the result of ``LabelData.broadestType`` over the candidate types
        :raises TypeError: if ``dataType`` is not a record, lacks a field that
            was used in training, or declares a field whose type does not match
            that field's type in the training dataset
        """

        if dataType.get("type", None) != "record":
            raise TypeError("dataType must be a record")

        dataFieldNames = set(x["name"] for x in dataType["fields"])

        # Report exactly which training fields the input schema lacks, in training order.
        missing = [name for name in self.dataset.names if name not in dataFieldNames]
        if missing:
            raise TypeError(
                "dataType must be a record with at least as many fields as the ones used to train the tree"
                " (missing: {0})".format(", ".join("\"{0}\"".format(name) for name in missing)))

        # Position of each training field by name, built once rather than calling
        # list.index for every input field; setdefault keeps the first occurrence,
        # which matches list.index semantics.
        trainingIndex = {}
        for index, name in enumerate(self.dataset.names):
            trainingIndex.setdefault(name, index)

        dataFieldTypes = []
        for field in dataType["fields"]:
            fieldIndex = trainingIndex.get(field["name"])
            if fieldIndex is None:
                # Input fields the tree was not trained on do not affect the value type.
                continue

            trainingType = self.dataset.fields[fieldIndex].tpe
            if trainingType == numbers.Real:
                if field["type"] not in ("int", "long", "float", "double"):
                    raise TypeError("dataType field \"{0}\" must be a numeric type, since this was a numeric type in the dataset training".format(field["name"]))
                dataFieldTypes.append(field["type"])
            elif trainingType == basestring:
                if field["type"] != "string":
                    raise TypeError("dataType field \"{0}\" must be a string, since this was a string in the dataset training".format(field["name"]))
                if self.maxSubsetSize == 1:
                    dataFieldTypes.append(field["type"])
                else:
                    # NOTE(review): presumably string splits compare against a set of
                    # categories when maxSubsetSize > 1 — confirm against the tree builder.
                    dataFieldTypes.append({"type": "array", "items": field["type"]})
            # NOTE(review): training fields of any other type contribute no candidate
            # type and are skipped silently — confirm this is intended.

        asjson = [json.dumps(x) for x in dataFieldTypes]
        # The parsed types are read from the parser result's values().
        astypes = ForwardDeclarationParser().parse(asjson)
        return LabelData.broadestType(astypes.values())
예제 #2
0
파일: cart.py 프로젝트: Marigold/hadrian
    def pfaValueType(self, dataType):
        """Create an Avro schema representing the comparison value type.

        Every field of ``dataType`` that was also used to train the tree
        contributes one candidate schema:

        * a numeric training field contributes its declared Avro numeric type;
        * a string training field contributes ``"string"``, or an array of
          strings when ``self.maxSubsetSize`` is not 1.

        The candidates are parsed with ``ForwardDeclarationParser`` and
        combined by ``LabelData.broadestType``.

        :type dataType: Pythonized JSON
        :param dataType: Avro record schema of the input data
        :rtype: Pythonized JSON
        :return: value type (``value`` field of the PFA ``TreeNode``)
        :raises TypeError: if ``dataType`` is not a record, lacks a field that
            was used in training, or declares a field whose type does not match
            that field's type in the training dataset
        """

        if dataType.get("type", None) != "record":
            raise TypeError("dataType must be a record")

        dataFieldNames = {x["name"] for x in dataType["fields"]}

        # Report exactly which training fields the input schema lacks, in training order.
        missing = [name for name in self.dataset.names if name not in dataFieldNames]
        if missing:
            raise TypeError(
                "dataType must be a record with at least as many fields as the ones used to train the tree"
                " (missing: {0})".format(", ".join("\"{0}\"".format(name) for name in missing))
            )

        # Position of each training field by name, built once rather than calling
        # list.index for every input field; setdefault keeps the first occurrence,
        # which matches list.index semantics.
        trainingIndex = {}
        for index, name in enumerate(self.dataset.names):
            trainingIndex.setdefault(name, index)

        dataFieldTypes = []
        for field in dataType["fields"]:
            fieldIndex = trainingIndex.get(field["name"])
            if fieldIndex is None:
                # Input fields the tree was not trained on do not affect the value type.
                continue

            trainingType = self.dataset.fields[fieldIndex].tpe
            if trainingType == numbers.Real:
                if field["type"] not in ("int", "long", "float", "double"):
                    raise TypeError(
                        "dataType field \"{0}\" must be a numeric type, since this was a numeric type in the dataset training"
                        .format(field["name"]))
                dataFieldTypes.append(field["type"])
            elif trainingType == str:
                if field["type"] != "string":
                    raise TypeError(
                        "dataType field \"{0}\" must be a string, since this was a string in the dataset training"
                        .format(field["name"]))
                if self.maxSubsetSize == 1:
                    dataFieldTypes.append(field["type"])
                else:
                    # NOTE(review): presumably string splits compare against a set of
                    # categories when maxSubsetSize > 1 — confirm against the tree builder.
                    dataFieldTypes.append({
                        "type": "array",
                        "items": field["type"]
                    })
            # NOTE(review): training fields of any other type contribute no candidate
            # type and are skipped silently — confirm this is intended.

        asjson = [json.dumps(x) for x in dataFieldTypes]
        # The parsed types are read from the parser result's values().
        astypes = ForwardDeclarationParser().parse(asjson)
        return LabelData.broadestType(astypes.values())
예제 #3
0
 def broadestType(self, types):
     """Delegate to ``LabelData.broadestType`` for the given *types*.

     Instance-level convenience wrapper: ``self`` is not consulted, and the
     result is exactly what ``LabelData.broadestType(types)`` returns
     (presumably the most general type covering all of *types* — see
     ``LabelData`` for the actual rule).
     """
     combine = LabelData.broadestType
     return combine(types)