from typing import Sequence, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from lifelines import CoxPHFitter, KaplanMeierFitter, NelsonAalenFitter
from lifelines.statistics import logrank_test, multivariate_logrank_test


def get_scores(model, x_test, y_test, delta_test, time_grid,
               surv_residual=False, cens_residual=False):
    n = y_test.shape[0]
    x_train, target = model.training_data
    y_train, delta_train = target

    # Compute residuals exp(log(t) - theta(x)) from the training and test data.
    exp_residual_train = np.nan_to_num(
        np.exp(np.log(y_train) - model.predict(x_train).reshape(-1)))
    exp_residual_test = np.nan_to_num(
        np.exp(np.log(y_test) - model.predict(x_test).reshape(-1)))

    # Compute exp(-theta) from the test data to evaluate the accelerating
    # component.
    exp_predict_neg_test = np.nan_to_num(
        np.exp(-model.predict(x_test)).reshape(-1))

    naf_base = NelsonAalenFitter().fit(y_train, event_observed=delta_train)
    kmf_cens = KaplanMeierFitter().fit(y_train, event_observed=1 - delta_train)

    # Evaluate the censoring survival function at the residuals or raw times.
    if cens_residual:
        cens_test = kmf_cens.survival_function_at_times(exp_residual_test)
    else:
        cens_test = kmf_cens.survival_function_at_times(y_test)

    bss = []
    nblls = []
    for t in time_grid:
        bs, nbll = get_score(n, t, y_test, delta_test, naf_base, kmf_cens,
                             cens_test, exp_predict_neg_test, surv_residual,
                             cens_residual, model)
        bss.append(bs)
        nblls.append(-nbll)
    return np.array(bss), np.array(nblls)
def integrated_brier_score(time_true: np.ndarray,
                           time_pred: np.ndarray,
                           event_observed: np.ndarray,
                           time_bins: np.ndarray) -> float:
    r"""Compute the integrated Brier score for a predicted survival function.

    The integrated Brier score is defined as the mean squared error between
    the true and predicted survival functions at time t, integrated over all
    timepoints.

    Parameters
    ----------
    time_true : np.ndarray, shape=(n_samples,)
        The true time to event or censoring for each sample.
    time_pred : np.ndarray, shape=(n_samples, n_time_bins)
        The predicted survival probabilities for each sample in each time bin.
    event_observed : np.ndarray, shape=(n_samples,)
        The event indicator for each sample (1 = event, 0 = censoring).
    time_bins : np.ndarray, shape=(n_time_bins,)
        The time bins for which the survival function was computed.

    Returns
    -------
    float
        The integrated Brier score of the predictions.

    Notes
    -----
    This function uses the definition from [1]_ with inverse probability of
    censoring weighting (IPCW) to correct for censored observations. The
    weights are computed using the Kaplan-Meier estimate of the censoring
    distribution.

    References
    ----------
    .. [1] E. Graf, C. Schmoor, W. Sauerbrei, and M. Schumacher, "Assessment
       and comparison of prognostic classification schemes for survival
       data", Statistics in Medicine, vol. 18, no. 17-18, pp. 2529-2545,
       Sep. 1999.
    """
    # Compute weights for inverse probability of censoring weighting (IPCW).
    censoring_km = KaplanMeierFitter()
    censoring_km.fit(time_true, 1 - event_observed)
    weights_event = censoring_km.survival_function_at_times(
        time_true).values.reshape(-1, 1)
    weights_no_event = censoring_km.survival_function_at_times(
        time_bins).values.reshape(1, -1)

    # Scores for subjects with an event before time t, for each time bin.
    had_event = (time_true[:, np.newaxis] <= time_bins) \
        & event_observed[:, np.newaxis]
    scores_event = np.where(had_event, (0 - time_pred)**2 / weights_event, 0)

    # Scores for subjects with no event and no censoring before time t,
    # for each time bin.
    scores_no_event = np.where(time_true[:, np.newaxis] > time_bins,
                               (1 - time_pred)**2 / weights_no_event, 0)
    scores = np.mean(scores_event + scores_no_event, axis=0)

    # Integrate over all time bins.
    score = np.trapz(scores, time_bins) / time_bins.max()
    return score
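# A minimal usage sketch for integrated_brier_score with synthetic data; the
# toy prediction reuses a single exponential survival curve for every sample,
# so this only illustrates the expected shapes and the call signature.
rng = np.random.default_rng(0)
n_samples, n_bins = 200, 50
time_true = rng.exponential(10.0, size=n_samples)
event_observed = rng.integers(0, 2, size=n_samples)
time_bins = np.linspace(0.1, 20.0, n_bins)
time_pred = np.tile(np.exp(-time_bins / 10.0), (n_samples, 1))
print(integrated_brier_score(time_true, time_pred, event_observed, time_bins))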
def signedByMedianSurvival(
    T: pd.Series,
    E: pd.Series,
    mask: pd.Series,
    *,
    alternative_mask: Union[None, pd.Series] = None,
    timeline: Union[None, Sequence] = None,
) -> int:
    """
    Decide if the group defined by mask has a better (+1) or worse (-1)
    prognosis. Since mask typically defines low expression, +1 means
    accelerating disease.

    Parameters
    ----------
    T
        A series of survival times.
    E
        A series of events, where 1 is the event (death).
    mask
        A pandas mask for the main group.
    alternative_mask
        By default, the second group is the negation of the first mask.
        This parameter sets a custom mask instead.
    timeline
        A series of time points (days in TCGA) at which to sample survival
        probabilities.

    Returns
    -------
    Sign of the survival benefit for the grouping mask.
    """
    if alternative_mask is None:
        alternative_mask = ~mask
    if timeline is None:
        timeline = [0, 1500, 3000, 4500, 6000, 7500, 9000]

    kmf1 = KaplanMeierFitter()
    kmf1.fit(T[mask], E[mask])
    kmf2 = KaplanMeierFitter()
    kmf2.fit(T[alternative_mask], E[alternative_mask])

    # Compare the areas under the two survival curves on the common timeline.
    if np.trapz(kmf1.survival_function_at_times(timeline)) > np.trapz(
        kmf2.survival_function_at_times(timeline)
    ):
        return 1
    return -1
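# A minimal usage sketch for signedByMedianSurvival with toy data (the column
# names are illustrative, not from the source).
toy = pd.DataFrame({
    "time": [100, 400, 900, 1600, 2500, 3600, 200, 800],
    "event": [1, 1, 0, 1, 0, 1, 1, 0],
    "expression": [0.1, 0.2, 0.3, 0.9, 0.8, 0.7, 0.15, 0.85],
})
low_expression = toy["expression"] < toy["expression"].median()
print(signedByMedianSurvival(toy["time"], toy["event"], low_expression))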
def show_survival_curve(df, t_col, y_col, max_time=None, weight=None,
                        save_file=None):
    plt.figure(figsize=(8, 6))
    plt.rcParams["font.size"] = 14
    colors = ['blue', 'red', 'magenta']

    tr_uniq = np.sort(df[t_col].astype(int).unique())
    max_time = df[y_col].max() if max_time is None else max_time
    time = df[y_col].values
    # Treat observations that reach max_time as censored.
    event = np.where(df[y_col] < max_time, 1, 0)

    verbose_days = [
        0,
        int((max_time - 1) / 3),
        int((max_time - 1) * 2 / 3),
        int(max_time) - 1,
    ]
    for d in verbose_days:
        plt.text(d, 0.6, f'RR({d}day)', horizontalalignment='center',
                 verticalalignment='center')

    curve_list = []
    elapsed_days = np.arange(int(max_time))
    kmf = KaplanMeierFitter()
    for i, tr in enumerate(tr_uniq):
        t_idx = (df[t_col] == tr)
        if weight is None:
            kmf.fit(time[t_idx], event[t_idx], label=f'tr={tr}')
        else:
            kmf.fit(time[t_idx], event[t_idx], label=f'tr={tr}',
                    weights=weight[t_idx])
        curve_list.append(kmf.survival_function_at_times(elapsed_days))
        ax = kmf.plot(c=colors[i])
        # Annotate the survival probability at each verbose day.
        for d in verbose_days:
            surv_prob = kmf.survival_function_at_times(d).values[0]
            plt.scatter(d, surv_prob, marker='o', c=colors[i])
            plt.text(d, 0.6 - 0.02 * (i + 1), f'{surv_prob:.3f}',
                     c=colors[i], horizontalalignment='center',
                     verticalalignment='center')

    plt.xlim(-3, int(max_time) + 3)
    plt.ylim(0.5, 1.05)
    plt.xlabel('Followed days (elapsed days)')
    plt.ylabel('Survival probability (retention rate)')
    plt.legend(loc='best')
    plt.grid()
    plt.tight_layout()
    if save_file is not None:
        plt.savefig(save_file)
    plt.show()

    # Return the per-day difference in retention between the second and
    # first treatment groups.
    return (np.array(curve_list[1]) - np.array(curve_list[0])).reshape(-1)
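# A minimal usage sketch for show_survival_curve with toy retention data (the
# "treatment"/"days" column names are illustrative, not from the source).
rng = np.random.default_rng(1)
toy_retention = pd.DataFrame({
    "treatment": [0] * 50 + [1] * 50,
    "days": rng.integers(1, 31, size=100),
})
uplift = show_survival_curve(toy_retention, t_col="treatment", y_col="days",
                             max_time=30)
print(uplift[:5])  # per-day retention difference between the two groups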
from django.http import HttpRequest, JsonResponse


def index_of_survival(request: HttpRequest, all_parameter: str):
    """
    response = {
        "data": [...]
    }
    """
    # Parse the query string: survival type, cancer types, cell types, the
    # upper/lower expression percentiles, and the deconvolution reference.
    mm = all_parameter.split("&")
    st = mm[0].split("=")[1]
    if "," in mm[1].split("=")[1]:
        ct = mm[1].split("=")[1].split(",")
    else:
        ct = [mm[1].split("=")[1]]

    b = API.DatabaseAPI("tcga")  # API is a project-level database wrapper
    my_dict_b = b.query_collection_obs()
    my_df_b = pd.DataFrame(my_dict_b)
    select_part = my_df_b.loc[my_df_b["primary_disease"].isin(ct), :]
    if len(ct) > 8:
        response = {
            "error": "Too many datasets. You can select no more than "
                     "eight datasets."
        }
        return JsonResponse(response)

    ref = mm[5].split("=")[1]
    if "," in mm[2].split("=")[1]:
        cell = mm[2].split("=")[1].split(",")
    else:
        cell = [mm[2].split("=")[1]]
    up = mm[3].split("=")[1]
    dn = mm[4].split("=")[1]
    select = select_part["primary_disease"].tolist()

    if ref == "EPIC":
        columns_list = ["EPIC_cellFractions." + i for i in cell]
        ref = API.DatabaseAPI("ref")
    elif ref == "LM":
        columns_list = ["LM_" + i for i in cell]
        ref = API.DatabaseAPI("LM_ref")
    elif ref == "QS":
        columns_list = ["QS_" + i for i in cell]
        ref = API.DatabaseAPI("QS_ref")
    else:
        response = {"error": "reference error"}
        return JsonResponse(response)

    cellID = select_part["cellID"].tolist()
    my_df_d = select_part.loc[:, columns_list]
    my_df_d.index = cellID
    my_df_t = my_df_d.T

    # Weight cell fractions by the mean reference expression per cell type.
    genes = ref.query_collection_var()["geneSymbol"]
    gg = ref.query_collection_gene_X_var_by_obs(genes)
    gg = pd.DataFrame(gg)
    gg.columns = ref.query_collection_obs()["celltype"]
    gg_mean = pd.DataFrame(gg.T.mean(axis=1))
    gg_mean = gg_mean.loc[cell, :]
    expression = my_df_t.multiply(gg_mean.values)
    expression_t = expression.T
    expression_t = pd.DataFrame(expression_t.sum(axis=1), columns=["sum"])
    expression_t = expression_t.sort_values(by=["sum"], ascending=False)

    # Split samples at the requested upper/lower percentiles, keeping only
    # the first three dot-separated fields of each sample barcode.
    number = expression_t.shape[0]
    number1 = int(number / 100 * (100 - int(up)))
    number2 = int(number / 100 * (100 - int(dn)))
    samples = expression_t.index.tolist()
    sample = []
    for each in samples:
        names = each.split(".")
        sample.append(names[0] + "." + names[1] + "." + names[2])
    up_sample = sample[:number1]
    dn_sample = sample[number2 + 1:]

    matches = {"Dead": 1, "Alive": 0, "-": 0}
    a = API.DatabaseAPI("survival")

    # Up part.
    my_dict_a = a.query_collection_obs()
    my_df_a = pd.DataFrame(my_dict_a)
    my_df_a = my_df_a.loc[my_df_a["sample"].isin(up_sample), :]
    OSEVENT = my_df_a["OSEVENT"].tolist()
    E = [matches[i] for i in OSEVENT]
    if st == "OS":
        T = my_df_a["OSDAY"].tolist()
    else:
        T = my_df_a["RFSDAY"].tolist()
    E_end_up = [E[i] for i in range(len(T)) if T[i] != "-"]
    T_end = [T[i] for i in range(len(T)) if T[i] != "-"]
    T_end = list(map(float, T_end))
    T_end_up = list(map(lambda x: round(x / 30, 2), T_end))  # days -> months

    kmf = KaplanMeierFitter()
    kmf.fit(T_end_up, E_end_up)
    sf = kmf.survival_function_.T
    xa = sf.columns.tolist()
    y1a = list(map(lambda x: round(x, 3), sf.values[0].tolist()))
    ci = kmf.confidence_interval_survival_function_.T.values
    y2a = list(map(lambda x: round(x, 3), ci[1].tolist()))
    y3a = list(map(lambda x: round(x, 3), ci[0].tolist()))
    # Censored observations, marked as crosses on the curve.
    xca = [T_end_up[i] for i in range(len(T_end_up)) if E_end_up[i] == 0]
    xca = list(map(float, xca))
    yca = list(map(lambda x: round(x, 3),
                   kmf.survival_function_at_times(xca).tolist()))

    # Down part.
    my_dict_a = a.query_collection_obs()
    my_df_a = pd.DataFrame(my_dict_a)
    my_df_a = my_df_a.loc[my_df_a["sample"].isin(dn_sample), :]
    OSEVENT = my_df_a["OSEVENT"].tolist()
    E = [matches[i] for i in OSEVENT]
    if st == "OS":
        T = my_df_a["OSDAY"].tolist()
    else:
        T = my_df_a["RFSDAY"].tolist()
    E_end_dn = [E[i] for i in range(len(T)) if T[i] != "-"]
    T_end = [T[i] for i in range(len(T)) if T[i] != "-"]
    T_end = list(map(float, T_end))
    T_end_dn = list(map(lambda x: round(x / 30, 2), T_end))

    kmf = KaplanMeierFitter()
    kmf.fit(T_end_dn, E_end_dn)
    sf = kmf.survival_function_.T
    xb = sf.columns.tolist()
    y1b = list(map(lambda x: round(x, 3), sf.values[0].tolist()))
    ci = kmf.confidence_interval_survival_function_.T.values
    y2b = list(map(lambda x: round(x, 3), ci[1].tolist()))
    y3b = list(map(lambda x: round(x, 3), ci[0].tolist()))
    xcb = [T_end_dn[i] for i in range(len(T_end_dn)) if E_end_dn[i] == 0]
    xcb = list(map(float, xcb))
    ycb = list(map(lambda x: round(x, 3),
                   kmf.survival_function_at_times(xcb).tolist()))

    # Log-rank test and Cox regression between the two groups.
    results = logrank_test(T_end_up, T_end_dn,
                           event_observed_A=E_end_up,
                           event_observed_B=E_end_dn)
    pValues1 = float(results.summary["p"].values[0])
    dfA = pd.DataFrame({'E': E_end_up, 'T': T_end_up, 'groupA': 1})
    dfB = pd.DataFrame({'E': E_end_dn, 'T': T_end_dn, 'groupA': 0})
    df = pd.concat([dfA, dfB])
    cph = CoxPHFitter().fit(df, 'T', 'E')
    pValues2 = float(cph.summary["p"].values[0])

    # Plotly traces: solid lines are the KM estimates, dashed lines the
    # confidence bands, and black crosses mark censored samples.
    response = {
        "data": [
            {"pValues1": pValues1, "pValues2": pValues2},
            {"line": {"dash": "solid", "color": "red", "shape": "hv", "width": 2},
             "mode": "lines", "name": "", "type": "scatter", "x": xa, "y": y1a,
             "xaxis": "x1", "yaxis": "y1", "showlegend": False},
            {"line": {"dash": "dash", "color": "red", "shape": "hv", "width": 2},
             "mode": "lines", "name": "", "type": "scatter", "x": xa, "y": y2a,
             "xaxis": "x1", "yaxis": "y1", "showlegend": False},
            {"line": {"dash": "dash", "color": "red", "shape": "hv", "width": 2},
             "mode": "lines", "name": "", "type": "scatter", "x": xa, "y": y3a,
             "xaxis": "x1", "yaxis": "y1", "showlegend": False},
            {"mode": "markers", "name": "", "text": "", "type": "scatter",
             "x": xca, "y": yca, "xaxis": "x1", "yaxis": "y1",
             "marker": {"size": 10, "color": "black",
                        "symbol": "cross-thin-open", "opacity": 1,
                        "sizeref": 1, "sizemode": "area"},
             "showlegend": False},
            {"line": {"dash": "solid", "color": "blue", "shape": "hv", "width": 2},
             "mode": "lines", "name": "", "type": "scatter", "x": xb, "y": y1b,
             "xaxis": "x1", "yaxis": "y1", "showlegend": False},
            {"line": {"dash": "dash", "color": "blue", "shape": "hv", "width": 2},
             "mode": "lines", "name": "", "type": "scatter", "x": xb, "y": y2b,
             "xaxis": "x1", "yaxis": "y1", "showlegend": False},
            {"line": {"dash": "dash", "color": "blue", "shape": "hv", "width": 2},
             "mode": "lines", "name": "", "type": "scatter", "x": xb, "y": y3b,
             "xaxis": "x1", "yaxis": "y1", "showlegend": False},
            {"mode": "markers", "name": "", "text": "", "type": "scatter",
             "x": xcb, "y": ycb, "xaxis": "x1", "yaxis": "y1",
             "marker": {"size": 10, "color": "black",
                        "symbol": "cross-thin-open", "opacity": 1,
                        "sizeref": 1, "sizemode": "area"},
             "showlegend": False},
        ]
    }
    return JsonResponse(response)
S2 = data[data.Stage_group == 2]
km2 = KM()  # KM: presumably lifelines.KaplanMeierFitter, aliased earlier
km2.fit(S2.loc[:, "Time"], event_observed=S2.loc[:, 'Event'],
        label='Stage IV')

ax = km1.plot(ci_show=False)
km2.plot(ax=ax, ci_show=False)
plt.xlabel('time')
plt.ylabel('Survival probability estimate')
plt.savefig('two_km_curves', dpi=300)

# Let's compare the survival functions at 90, 180, 270, and 360 days.

# In[37]:

survivals = pd.DataFrame([90, 180, 270, 360], columns=['time'])
survivals.loc[:, 'Group 1'] = km1.survival_function_at_times(
    survivals['time']).values
survivals.loc[:, 'Group 2'] = km2.survival_function_at_times(
    survivals['time']).values

# In[38]:

survivals

# This makes clear the difference in survival between the Stage III and
# Stage IV cancer groups in the dataset.

# <a name='5-1'></a>
# ## 5.1 Bonus: Log-Rank Test
#
# To say whether there is a statistical difference between the survival
# curves, we can run the log-rank test. This test tells us the probability
# that we could observe this data if the two curves were the same. The
# derivation of the log-rank test is somewhat complicated, but luckily
# `lifelines` has a simple function to compute it.
#
# Run the next cell to compute a p-value using
# `lifelines.statistics.logrank_test`.
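# In[39]:

# A sketch of that next cell, assuming `S1` is the Stage III subset fitted
# by `km1` in an earlier (unshown) cell; `logrank_test` is imported at the
# top of this file from `lifelines.statistics`.
results = logrank_test(S1.loc[:, 'Time'], S2.loc[:, 'Time'],
                       event_observed_A=S1.loc[:, 'Event'],
                       event_observed_B=S2.loc[:, 'Event'])
print(f'log-rank p-value: {results.p_value:.4f}')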
def filter_survival(filter_id):
    # `filtering`, `binning`, `ntiles`, and `uniq` are project-level helpers.
    data_filtered = filtering(filter_id)
    data_filtered = data_filtered[
        data_filtered[filter_id['cell_full']] != "missing"]

    # Build the groups (ntiles) and run the Kaplan-Meier fitter for each of
    # them. If group_sizes is provided, use the binning function; otherwise
    # fall back to general ntiles.
    if filter_id['group_sizes'] is not None:
        data_filtered['rank'] = binning(
            data_filtered.sort_values(by=filter_id['cell_full']),
            filter_id['cell_full'], filter_id['group_sizes'])
    else:
        data_filtered['rank'] = ntiles(
            data_filtered[filter_id['cell_full']], filter_id['num_groups'])

    # NB: check the number of groups after filtering.
    num_groups = len(uniq(data_filtered['rank']))
    if num_groups < 2:
        raise ValueError('Number of groups must be at least two.')

    points_dfs = []
    alive_dfs = []
    for g in range(num_groups):
        kmf = KaplanMeierFitter()
        data = data_filtered[lambda df: df['rank'] == g + 1]
        kmf.fit(data['T'], data['E'], label='Kaplan_Meier')
        df = pd.concat([
            kmf.survival_function_,
            kmf.confidence_interval_survival_function_,
        ], axis=1)
        df['group'] = g + 1
        points_dfs += [df]
        # Survival probabilities at the censored times, for plotting marks.
        alive_df = kmf.survival_function_at_times(
            data[data['E'] == False]['T']).to_frame().reset_index()
        alive_df['group'] = g + 1
        alive_dfs += [alive_df]

    # Curate points and alive points.
    points_df = pd.concat(points_dfs).reset_index().rename(columns={
        'index': 'time',
        'Kaplan_Meier': 'fit',
        'Kaplan_Meier_lower_0.95': 'lower',
        'Kaplan_Meier_upper_0.95': 'upper',
    })
    points = points_df.to_dict(orient='records')

    alive_points_df = pd.concat(alive_dfs).rename(columns={
        'index': 'time',
        'Kaplan_Meier': 'fit',
    })
    alive_points = alive_points_df.to_dict(orient='records')

    # Run the multivariate log-rank test across all groups.
    log_rank = multivariate_logrank_test(
        data_filtered['T'], data_filtered['rank'], data_filtered['E'])
    log = {
        'test_statistic_logrank': log_rank.summary['test_statistic'][0],
        'p_logrank': log_rank.summary['p'][0],
    }

    # Run a Cox regression with the group rank as the covariate.
    cph = CoxPHFitter()
    cph.fit(data_filtered[['rank', 'T', 'E']], 'T', event_col='E')
    cox = {
        'coef': cph.summary['exp(coef)'][0],
        'lower': cph.summary['exp(coef) lower 95%'][0],
        'upper': cph.summary['exp(coef) upper 95%'][0],
        'p': cph.summary['p'][0],
    }
    # Replace infinite values with the max/min probabilities.
    if cox['upper'] == float('inf'):
        cox['upper'] = 1.0
    if cox['lower'] == float('-inf'):
        cox['lower'] = 0.0

    return {'points': points, 'log_rank': log, 'cox_regression': cox,
            'live_points': alive_points}
import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist
import pyro.infer as infer
import torch
from pyro.contrib import autoname
from scipy.cluster.vq import kmeans2
from sklearn.preprocessing import StandardScaler

# `LinearT`, `output_layer`, and `sigmoid_calibrate_survival_predictions`
# are project-level helpers defined elsewhere.


class DGPSurv(gp.parameterized.Parameterized):

    def __init__(self, X, T, c, prediction_horizon, layer_dim=30,
                 num_causes=1, num_inducing=100, calibration_fraction=0.5,
                 calibrate=False):
        super(DGPSurv, self).__init__()

        self.prediction_horizon = prediction_horizon
        self.calibrate = calibrate

        # Refine inputs: keep subjects followed past the horizon or with an
        # observed event, and zero out events occurring after the horizon.
        inclusion_criteria = (T >= self.prediction_horizon) | (c != 0)
        X_ = np.array(X)[inclusion_criteria, :]
        c_ = np.array(
            ((np.array(c)[inclusion_criteria] != 0)
             & (np.array(T)[inclusion_criteria] < self.prediction_horizon))
            * np.array(c)[inclusion_criteria])

        # Set all model attributes (note: despite the name, minmax_ is a
        # standard scaler).
        self.minmax_ = StandardScaler()
        self.X = torch.tensor(np.array(self.minmax_.fit_transform(X_))).float()
        self.T = torch.tensor(np.array(T).astype(float) / 365).float()
        self.c = torch.tensor(np.array(c_).astype(float)).float()
        self.num_inducing = min(num_inducing, self.X.shape[0])
        self.num_causes = num_causes + 1
        self.num_dim = self.X.shape[1]
        self.Xu = torch.from_numpy(
            kmeans2(self.X.numpy(), self.num_inducing, minit='points')[0])

        # Handle erroneous settings for the model's parameters.
        try:
            self.layer_dim = layer_dim
            if self.layer_dim < 2:
                raise ValueError(
                    "Bad inputs: number of intermediate dimensions must be "
                    "greater than 2.")
        except ValueError as ve:
            print(ve)

        # Compute the weight for the mean function of the first layer using
        # a PCA transformation.
        _, _, V = np.linalg.svd(self.X.numpy(), full_matrices=False)
        W = torch.from_numpy(V[:self.layer_dim, :])
        mean_fn = LinearT(self.num_dim, self.layer_dim)
        mean_fn.linear.weight.data = W
        mean_fn.linear.weight.requires_grad_(False)
        self.mean_fn = mean_fn

        # Initialize the first DGP layer.
        linear = torch.nn.Linear(self.num_dim, 20)
        pyro_linear_fn = lambda x: pyro.module("linear", linear)(x)
        kernel = gp.kernels.Matern32(input_dim=self.num_dim,
                                     lengthscale=torch.tensor(1.))
        warped_kernel = gp.kernels.Warping(kernel, pyro_linear_fn)
        self.layer_0 = gp.models.VariationalSparseGP(
            self.X,
            None,
            gp.kernels.Matern52(self.num_dim,
                                variance=torch.tensor(1.),
                                lengthscale=torch.ones(self.num_dim)),
            # warped_kernel,
            Xu=self.Xu,
            likelihood=None,
            mean_function=self.mean_fn,
            latent_shape=torch.Size([self.layer_dim]))

        h = self.mean_fn(self.X).t()
        hu = self.mean_fn(self.Xu).t()
        self.layer_1 = gp.models.VariationalSparseGP(
            h,
            self.c,
            gp.kernels.Matern52(self.layer_dim,
                                variance=torch.tensor(1.),
                                lengthscale=torch.tensor(1.)),
            Xu=hu,
            likelihood=gp.likelihoods.MultiClass(num_classes=self.num_causes),
            latent_shape=torch.Size([self.num_causes]))

        # self.layer_0.u_scale_tril = self.layer_0.u_scale_tril * 1e-5
        # self.layer_0.set_constraint(
        #     "u_scale_tril", torch.distributions.constraints.lower_cholesky)

        if self.calibrate:
            self.kmf = KaplanMeierFitter()
            self.kmf.fit(T, event_observed=c)
            self.offset_probability = self.kmf.survival_function_at_times(
                times=[self.prediction_horizon]).values[0]
            self.calibration_fraction = calibration_fraction

    @autoname.name_count
    def model(self, X, c):
        self.layer_0.set_data(X, None)
        h_loc, h_var = self.layer_0.model()
        h = dist.Normal(h_loc, h_var.sqrt())()
        self.layer_1.set_data(h.t(), c)
        self.layer_1.model()

    @autoname.name_count
    def guide(self, X, c):
        self.layer_0.guide()
        self.layer_1.guide()

    # Make a prediction.
    def forward(self, X_new):
        # Because prediction is stochastic (due to the Monte Carlo sample of
        # the hidden layer), we make 100 predictions and take the most
        # common one (as in [4]).
        pred = []
        num_MC_samples = 100
        for _ in range(num_MC_samples):
            h_loc, h_var = self.layer_0(X_new)
            h = dist.Normal(h_loc, h_var.sqrt())()
            f_loc, f_var = self.layer_1(h.t())
            pred.append(f_loc)  # change for multiclass
        return torch.stack(pred).mode(dim=0)[0]

    def train(self, num_epochs=5, num_iters=60, batch_size=1000,
              learning_rate=0.01):
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        loss_fn = infer.TraceMeanField_ELBO().differentiable_loss
        self.loss = []
        for i in range(num_epochs):
            self.loss.append(
                self.train_update(optimizer, loss_fn, batch_size,
                                  num_iters, i))
        self.loss = np.array(self.loss).reshape((-1, 1))
        if self.calibrate:
            print("Calibrating the trained model...")
            calibration_indexes = np.random.choice(
                list(range(self.X.shape[0])),
                int(np.ceil(self.calibration_fraction * self.X.shape[0])))
            y_uncalibrated = self.predict_survival(
                self.X[calibration_indexes, :].detach().numpy(), calibrate=1)
            y_raw = np.log((1 - y_uncalibrated) / y_uncalibrated)
            self.calibration_constant = \
                sigmoid_calibrate_survival_predictions(self, y_raw)
            print("Done training!")
        else:
            self.calibration_constant = 1

    def train_update(self, optimizer, loss_fn, batch_size, num_iters, epoch):
        losses = []
        for it in range(num_iters):
            batch_indexes = np.random.choice(list(range(self.X.shape[0])),
                                             batch_size)
            features_ = self.X[batch_indexes, :]
            event_censor = self.c[batch_indexes]
            features_ = features_.reshape(-1, self.X.shape[1])
            optimizer.zero_grad()
            loss = loss_fn(self.model, self.guide, features_, event_censor)
            # Store the scalar so the later numpy conversion does not choke
            # on tensors that require grad.
            losses.append(loss.item())
            loss.backward()
            optimizer.step()
            print("Train Epoch: {:2d} \t[Iteration: {:2d}] \tLoss: {:.6f}"
                  .format(epoch, it, loss.item()))
        return losses

    def predict_survival(self, X_new, calibrate=None):
        s_preds = []
        y_pred = []
        index = 0
        base_size = 1000
        predictor_size = np.min((X_new.shape[0], base_size))
        num_batches_ = int(np.ceil(X_new.shape[0] / predictor_size))
        if calibrate is None:
            calibration_factor = self.calibration_constant
        else:
            calibration_factor = calibrate

        # Predict in batches to bound memory use.
        for u in range(num_batches_):
            if (u == (num_batches_ - 1)) and (np.mod(X_new.shape[0],
                                                     predictor_size) > 0):
                X_curr = np.array(X_new)[index:, :]
            else:
                X_curr = np.array(X_new)[index:index + predictor_size, :]
            X_new_numpy = self.minmax_.transform(X_curr)
            X_new_ = torch.tensor(X_new_numpy).float()
            f_output = self(X_new_).detach().numpy()
            if u == 0:
                s_preds = f_output
            else:
                s_preds = np.hstack((s_preds, f_output))
            index += predictor_size

        # Convert per-cause scores to normalized survival probabilities.
        for v in range(self.num_causes):
            y_pred.append(
                output_layer(constant=calibration_factor, y=s_preds[v, :]))
        y_pred = ((1 - np.array(y_pred))
                  / np.sum((1 - np.array(y_pred)), axis=0))
        return y_pred