class MarginalXGBClassifier(MarginalClassifier):
    def __init__(self,
                 method='softmax',
                 calibrator=LogitCalibrator,
                 MAX_LR=10):
        """Configure the XGBoost-backed marginal classifier.

        Args:
            method: ``'softmax'`` wraps a single ``XGBClassifier`` fitted on
                all classes; ``'sigmoid'`` wraps ``XGBClassifier`` in a
                ``OneVsRestClassifier`` (one model per class).
            calibrator: stored as ``self._calibrator``; how it is used is
                defined outside this method (default ``LogitCalibrator``).
            MAX_LR: stored as ``self.MAX_LR``; its meaning is defined by code
                outside this method.

        Raises:
            ValueError: if ``method`` is neither ``'softmax'`` nor
                ``'sigmoid'``.
        """
        # Validate early: an unknown method used to leave `_classifier`
        # unset, which only surfaced later as an AttributeError.
        if method not in ('softmax', 'sigmoid'):
            raise ValueError(
                "method must be 'softmax' or 'sigmoid', got {!r}".format(
                    method))

        # NOTE(review): XGBClassifier does not document a `class_weight`
        # parameter; unknown kwargs are forwarded to the booster and may be
        # ignored. Class balancing likely needs `sample_weight` in fit —
        # confirm.
        if method == 'softmax':
            self._classifier = XGBClassifier(class_weight='balanced')
        else:
            self._classifier = OneVsRestClassifier(
                XGBClassifier(class_weight='balanced'))

        self.method = method
        self._calibrator = calibrator
        # Starts empty; presumably populated per target class elsewhere.
        self._calibrators_per_target_class = {}
        self.MAX_LR = MAX_LR

    def fit_classifier(self, X, y):
        """Fit the wrapped classifier on (X, y) and record ``self.n_trees``.

        ``self.n_trees`` is the number of entries in the XGBoost booster's
        text dump (``get_booster().get_dump()``):

        * ``'softmax'``: the dump of the single fitted ``XGBClassifier``.
        * ``'sigmoid'``: the dump of the first per-class estimator of the
          fitted ``OneVsRestClassifier``.

        NOTE(review): for multi-class softmax the dump presumably contains
        one tree per class per boosting round, while the sigmoid branch
        counts only the first per-class model's trees. Confirm the two
        counts are meant to be compared.
        """
        self._classifier.fit(X, y)

        if self.method == 'softmax':
            booster = self._classifier.get_booster()
            self.n_trees = len(booster.get_dump())
        elif self.method == 'sigmoid':
            # `estimators_` is OneVsRestClassifier's public list of fitted
            # per-class models. The private `_first_estimator` alias used
            # previously is not part of scikit-learn's public API and is not
            # guaranteed across versions.
            first_estimator = self._classifier.estimators_[0]
            self.n_trees = len(first_estimator.get_booster().get_dump())