예제 #1
0
def _metric_kind(name):
    """Classify a metric name as 'rmse', 'median' or 'mean' (None if unknown).

    Matching is case-insensitive and by substring. The checks run in
    priority order so that a name such as 'root_mean_squared_error' maps to
    RMSE only, instead of also matching the 'mean' keyword.
    """
    name = name.lower()
    if 'rmse' in name or 'root' in name:
        return 'rmse'
    if 'median' in name:
        return 'median'
    if 'mean' in name:
        return 'mean'
    return None


def compare_to_actual(model,
                      X,
                      T,
                      E,
                      times=None,
                      is_at_risk=False,
                      figure_size=(16, 6),
                      metrics=('rmse', 'mean', 'median'),
                      **kwargs):
    """
    Comparing the actual and predicted number of units at risk and units
    experiencing an event at each time t.

    The "actual" curve comes from a Kaplan-Meier model fitted on (T, E) and
    scaled by the number of samples. The "predicted" curve is the model's
    prediction summed over all samples; for each t it uses the model time
    bucket whose start is closest to t. The two curves are plotted, with
    the Kaplan-Meier bounds shaded.

    Parameters:
    -----------
    * model : pysurvival model
        The model that will be used for prediction. It must expose
        `time_buckets` (pairs whose first element is the bucket start) and
        `predict_survival` / `predict_density`.

    * X : array-like, shape=(n_samples, n_features)
        The input samples.

    * T : array-like, shape = [n_samples]
        The target values describing when the event of interest or censoring
        occurred

    * E : array-like, shape = [n_samples]
        The Event indicator array such that E = 1. if the event occurred
        E = 0. if censoring occurred

    * times: array-like, (default=None)
        A vector of timepoints. If None, the time axis of the fitted
        Kaplan-Meier model is used.

    * is_at_risk: bool (default=False)
        If True, compares the Expected number of units at risk;
        otherwise compares the Expected number of units experiencing
        the events.

    * figure_size: tuple of double (default= (16, 6))
        width, height in inches representing the size of the chart

    * metrics: str or sequence of str (default=('rmse', 'mean', 'median'))
        Indicates the performance metrics to compute:
            - if None, then no metric is computed
            - if str, then that single metric is computed
            - if list / tuple / numpy array of str, then those metrics
              are computed

        The available metrics (matched case-insensitively by keyword) are:
            - 'rmse' or 'root': root mean squared error
            - 'median': median absolute error
            - 'mean': mean absolute error

    * kwargs:
        Forwarded to `model.predict_survival` / `model.predict_density`.

    Returns:
    --------
    * results: float, dict or None
        Performance metrics: a float for a str `metrics`, a dict keyed by
        metric name for a sequence, None when `metrics` is None.

    Raises:
    -------
    * NotImplementedError
        if a str metric, or every entry of a sequence, is not recognised.
    """

    # Initializing the Kaplan-Meier model
    X, T, E = utils.check_data(X, T, E)
    kmf = KaplanMeierModel()
    kmf.fit(T, E)

    # Scale factor turning per-unit probabilities into unit counts
    N = T.shape[0]

    # Defining the time axis
    if times is None:
        times = kmf.times

    # Start of each model time bucket; built once so the nearest-bucket
    # search per t is a single vectorized operation.
    bucket_starts = np.asarray([start for start, _ in model.time_buckets])

    # Model prediction summed over samples (one value per time bucket)
    if is_at_risk:
        model_predicted = np.sum(model.predict_survival(X, **kwargs), 0)
    else:
        model_predicted = np.sum(model.predict_density(X, **kwargs), 0)

    actual = []
    actual_upper = []
    actual_lower = []
    predicted = []
    for t in times:
        # argmin keeps the first bucket on ties
        index = np.argmin(np.abs(bucket_starts - t))
        predicted.append(model_predicted[index])

        # NOTE(review): the first argument of kmf.predict_* is passed as
        # None, as in the original; presumably the KM model ignores X.
        if is_at_risk:
            actual.append(N * kmf.predict_survival(None, t))
            actual_upper.append(N * kmf.predict_survival_upper(None, t))
            actual_lower.append(N * kmf.predict_survival_lower(None, t))
        else:
            actual.append(N * kmf.predict_density(None, t))
            h = kmf.predict_hazard(None, t)
            actual_upper.append(N * kmf.predict_survival_upper(None, t) * h)
            actual_lower.append(N * kmf.predict_survival_lower(None, t) * h)

    # Computing the performance metrics
    results = None
    title = 'Actual vs Predicted'
    if metrics is not None:

        # kind -> (result key, title label, value)
        scores = {
            'rmse': ('root_mean_squared_error', 'RMSE',
                     np.sqrt(mean_squared_error(actual, predicted))),
            'median': ('median_absolute_error', 'Median Abs Error',
                       median_absolute_error(actual, predicted)),
            'mean': ('mean_absolute_error', 'Mean Abs Error',
                     mean_absolute_error(actual, predicted)),
        }

        if isinstance(metrics, str):
            kind = _metric_kind(metrics)
            if kind is None:
                raise NotImplementedError(
                    '{} is not a valid metric function.'.format(metrics))
            _, label, value = scores[kind]
            results = value
            title += "\n{} = {:.3f}".format(label, value)

        elif isinstance(metrics, (list, tuple, np.ndarray)):
            requested = {_metric_kind(m) for m in metrics}
            results = {}

            # Fixed display order, independent of the order requested
            for kind in ('rmse', 'median', 'mean'):
                if kind in requested:
                    key, label, value = scores[kind]
                    results[key] = value
                    title += "\n{} = {:.3f}".format(label, value)

            if not results:
                error = 'The provided metrics are not available.'
                raise NotImplementedError(error)

    # Plotting
    fig, ax = plt.subplots(figsize=figure_size)
    ax.plot(times, actual, color='red', label='Actual', alpha=0.8, lw=3)
    ax.plot(times, predicted, color='blue', label='Predicted', alpha=0.8, lw=3)
    plt.xlim(0, max(T))

    # Filling the areas between the Survival and Confidence Intervals curves
    plt.fill_between(times,
                     actual,
                     actual_lower,
                     label='Confidence Intervals - Lower',
                     color='red',
                     alpha=0.2)
    plt.fill_between(times,
                     actual,
                     actual_upper,
                     label='Confidence Intervals - Upper',
                     color='red',
                     alpha=0.2)

    # Finalizing the chart
    plt.title(title, fontsize=15)
    plt.legend(fontsize=15)
    plt.show()

    return results
예제 #2
0
# Re-encode the churn column with le_churn so it can act as the event flag
# (presumably a fitted label encoder mapping No/Yes -> 0/1 — confirm).
telcom_data.Churn = le_churn.fit_transform(telcom_data.Churn)

# NOTE(review): the filters below assume `gender` is already 0/1-encoded
# (1 = Male, 0 = Female). That encoding is not visible here; if the column
# still holds 'Male'/'Female' strings, both subsets come out empty. Confirm
# against earlier preprocessing.
male_rows = telcom_data[telcom_data.gender == 1]
female_rows = telcom_data[telcom_data.gender == 0]

# Durations (tenure) and event indicators (Churn) per group
T_male, E_male = male_rows.tenure, male_rows.Churn
T_female, E_female = female_rows.tenure, female_rows.Churn

# One Kaplan-Meier estimator per gender
km_male_model = KaplanMeierModel()
km_male_model.fit(T_male, E_male, alpha=0.95)

km_female_model = KaplanMeierModel()
km_female_model.fit(T_female, E_female, alpha=0.95)

# Overlay both survival curves; female is drawn first, as before
for km_curve, curve_label in ((km_female_model, 'Female'),
                              (km_male_model, 'Male')):
    plt.plot(km_curve.times, km_curve.survival, label=curve_label)
plt.xlabel('Tenure - Months')
plt.ylabel('Probability of Survival')
plt.title('Kaplan-Meier Survival by Gender')
plt.legend()
plt.show()

# NOTE(review): assumes PaymentMethod is numerically encoded with 0 meaning
# bank transfer (inferred from the variable name only) — confirm.
T_bank_transfer = telcom_data[telcom_data.PaymentMethod == 0].tenure
예제 #3
0
def load_model(path_file):
    """ Load a Pysurvival model and its parameters from a .zip file

    The file is first loaded into a generic `BaseModel` to read the stored
    model `name`. That name selects the concrete Pysurvival class. A fresh
    instance of that class then receives a deep copy of every attribute
    that was loaded into the base model.

    Parameters:
    -----------
    * path_file, str
        address of the file where the model will be loaded from

    Returns:
    --------
    * pysurvival_model : Pysurvival object
        Pysurvival model

    Raises:
    -------
    * NotImplementedError
        if the stored model name matches no supported model (including an
        SVM name that is neither linear nor kernel)
    """

    # Initializing a base model
    from pysurvival.models import BaseModel
    base_model = BaseModel()

    # Temporary loading the model: only its name is needed to pick the class
    base_model.load(path_file)
    model_name = base_model.name
    name = model_name.lower()

    # Substring matching on the stored name; the first matching branch wins.
    pysurvival_model = None

    # Loading the actual Pysurvival model - Kaplan-Meier
    if 'kaplanmeier' in name:

        if 'smooth' in name:
            from pysurvival.models.non_parametric import SmoothKaplanMeierModel
            pysurvival_model = SmoothKaplanMeierModel()

        else:
            from pysurvival.models.non_parametric import KaplanMeierModel
            pysurvival_model = KaplanMeierModel()

    elif 'linearmultitask' in name:

        from pysurvival.models.multi_task import LinearMultiTaskModel
        pysurvival_model = LinearMultiTaskModel()

    elif 'neuralmultitask' in name:

        from pysurvival.models.multi_task import NeuralMultiTaskModel
        # Placeholder structure so the constructor can run; presumably the
        # saved structure replaces it in the __dict__ update below.
        structure = [
            {
                'activation': 'relu',
                'num_units': 128
            },
        ]
        pysurvival_model = NeuralMultiTaskModel(structure=structure)

    elif 'exponential' in name:

        from pysurvival.models.parametric import ExponentialModel
        pysurvival_model = ExponentialModel()

    elif 'weibull' in name:

        from pysurvival.models.parametric import WeibullModel
        pysurvival_model = WeibullModel()

    elif 'gompertz' in name:

        from pysurvival.models.parametric import GompertzModel
        pysurvival_model = GompertzModel()

    elif 'loglogistic' in name:

        from pysurvival.models.parametric import LogLogisticModel
        pysurvival_model = LogLogisticModel()

    elif 'lognormal' in name:

        from pysurvival.models.parametric import LogNormalModel
        pysurvival_model = LogNormalModel()

    elif 'simulation' in name:

        from pysurvival.models.simulations import SimulationModel
        pysurvival_model = SimulationModel()

    elif 'coxph' in name:

        if 'nonlinear' in name:
            from pysurvival.models.semi_parametric import NonLinearCoxPHModel
            pysurvival_model = NonLinearCoxPHModel()

        else:
            from pysurvival.models.semi_parametric import CoxPHModel
            pysurvival_model = CoxPHModel()

    elif 'random' in name and 'survival' in name:

        from pysurvival.models.survival_forest import RandomSurvivalForestModel
        pysurvival_model = RandomSurvivalForestModel()

    elif 'extra' in name and 'survival' in name:

        from pysurvival.models.survival_forest import ExtraSurvivalTreesModel
        pysurvival_model = ExtraSurvivalTreesModel()

    elif 'condi' in name and 'survival' in name:

        from pysurvival.models.survival_forest import ConditionalSurvivalForestModel
        pysurvival_model = ConditionalSurvivalForestModel()

    elif 'svm' in name:

        if 'linear' in name:

            from pysurvival.models.svm import LinearSVMModel
            pysurvival_model = LinearSVMModel()

        elif 'kernel' in name:

            from pysurvival.models.svm import KernelSVMModel
            pysurvival_model = KernelSVMModel()

    # Any unmatched name lands here. This includes an SVM name that is
    # neither linear nor kernel, which previously left pysurvival_model
    # unbound and crashed with UnboundLocalError.
    if pysurvival_model is None:
        raise NotImplementedError(
            '{} is not a valid pysurvival model.'.format(model_name))

    # Transferring the components
    pysurvival_model.__dict__.update(copy.deepcopy(base_model.__dict__))

    return pysurvival_model