コード例 #1
0
ファイル: arboreal_tree.py プロジェクト: stitchfix/arboreal
    def predict_proba(self, *args):
        """Return class-probability estimates (sklearn estimator interface).

        Only available when the tree was fit as a classifier on a categorical
        target.

        Args:
            *args: exactly one positional argument, either a pandas DataFrame
                or an arboreal Dataset. A DataFrame is converted to a Dataset
                using the metadata recorded at fit time rather than
                re-inferring column types.

        Returns:
            numpy array with one row per item yielded by ``self.transform``
            and one column per entry of ``self.classes_`` (in that order).
            Each value is that class's count in ``targets_reduced`` divided by
            the sum of all counts in ``targets_reduced``.

        Raises:
            ValueError: if the target is not categorical, the estimator is not
                a classifier, or the argument is not a single DataFrame /
                Dataset.
        """
        # predict_proba is only available for classifiers. Explicit raises
        # (not asserts) so the checks survive `python -O`.
        m = self.root.data[NODE_DATASET_KEY].metadata
        if not m.is_categorical(m.target):
            raise ValueError("predict_proba() requires a categorical target")
        if self._estimator_type != ESTIMATOR_TYPE_CLASSIFIER:
            raise ValueError("predict_proba() is only available for classifiers")

        if len(args) != 1:
            raise ValueError(
                "predict_proba() takes exactly one argument: "
                "a Pandas DataFrame or Arboreal Dataset"
            )
        if isinstance(args[0], pd.DataFrame):
            # Reuse the fit-time metadata so column types match fit().
            dataset = Dataset.from_pandas_X(args[0], m)
        elif isinstance(args[0], Dataset):
            dataset = args[0]
        else:
            raise ValueError(
                "Arg to predict_proba() must be a Pandas DataFrame or "
                "Arboreal Dataset"
            )

        # Gather probabilities in the format sklearn expects: one row per
        # datapoint, columns ordered like self.classes_.
        all_datapoints_class_probabilities = []
        for targets_reduced, _count, _datatype in self.transform(dataset):
            # Hoisted out of the per-class loop; it does not depend on klass.
            total = sum(targets_reduced.values())
            all_datapoints_class_probabilities.append(
                [targets_reduced[klass] / total for klass in self.classes_]
            )

        return np.array(all_datapoints_class_probabilities)
コード例 #2
0
ファイル: arboreal_tree.py プロジェクト: stitchfix/arboreal
    def predict(self, *args):
        """Predict one target value per datapoint (sklearn estimator interface).

        Args:
            *args: exactly one positional argument, either a pandas DataFrame
                or an arboreal Dataset. A DataFrame is converted to a Dataset
                using the metadata recorded at fit time.

        Returns:
            list with one prediction per item yielded by ``self.transform``:
            - for a numerical prediction datatype, the ``targets_reduced``
              value itself;
            - for a categorical one, the most common key of
              ``targets_reduced``.

        Raises:
            ValueError: if not called with exactly one DataFrame / Dataset, or
                if transform() yields a prediction datatype other than
                numerical or categorical.
        """
        # Explicit raise rather than assert so the check survives `python -O`.
        if len(args) != 1:
            raise ValueError(
                "predict() takes exactly one argument: "
                "a Pandas DataFrame or Arboreal Dataset"
            )

        if isinstance(args[0], pd.DataFrame):
            X = args[0]
            # Use the fit-time metadata rather than re-inferring column types.
            m = self.root.data[NODE_DATASET_KEY].metadata
            dataset = Dataset.from_pandas_X(X, m)
        elif isinstance(args[0], Dataset):
            dataset = args[0]
        else:
            raise ValueError(
                "Arg to predict() must be a Pandas DataFrame or Arboreal Dataset"
            )

        # return predictions in the format sklearn expects
        predictions = []
        for targets_reduced, _count, prediction_datatype in self.transform(dataset):
            if prediction_datatype == Datatype.numerical:
                prediction = targets_reduced
            elif prediction_datatype == Datatype.categorical:
                prediction = targets_reduced.most_common()[0][0]
            else:
                # Previously fell through here, leaving `prediction` unbound
                # (UnboundLocalError) or silently reusing the previous
                # datapoint's prediction.
                raise ValueError(
                    "Unsupported prediction datatype: {!r}".format(
                        prediction_datatype)
                )
            predictions.append(prediction)

        return predictions
コード例 #3
0
ファイル: arboreal_tree.py プロジェクト: stitchfix/arboreal
    def transform(self, *args):
        """Normalize the input to a Dataset and hand it to ``self._transform``.

        Works with arboreal Datasets directly as well as with pandas
        DataFrames, so the sklearn interface is supported. A DataFrame is
        converted to a Dataset using the metadata determined in fit() (which
        is called before transform()). Column types are not re-inferred from
        the DataFrame, which could disagree with the inference made at fit
        time.

        Args:
            *args: exactly one positional argument, a Dataset or a DataFrame.

        Returns:
            the result of ``self._transform`` on the resulting Dataset.

        Raises:
            ValueError: if not called with exactly one Dataset or DataFrame.
        """
        bad_input_message = (
            "transform requires either an Arboreal Dataset or Pandas DataFrame"
        )
        if len(args) != 1:
            raise ValueError(bad_input_message)

        (data,) = args
        if isinstance(data, Dataset):
            return self._transform(data)
        if isinstance(data, pd.DataFrame):
            fit_metadata = self.root.data[NODE_DATASET_KEY].metadata
            return self._transform(Dataset.from_pandas_X(data, fit_metadata))
        raise ValueError(bad_input_message)