def get_prediction(ANN, timed_Xtest, timed_ytest):
    """Predict on the test set and derive variations and reference crossings.

    Reads the module-level ``data`` table: its columns minus ``label`` are
    passed as the third argument to the ``mdl`` helpers (presumably the
    scaling reference -- confirm against ``mdl``).

    Unlike ``run_prediction``, this function stores nothing in globals and
    does not append the ``x``/``y``/``z``/``rho`` columns to its results.

    Args:
        ANN: Model handed to ``mdl.get_pred_timed`` / ``mdl.get_prob_timed``.
        timed_Xtest: Test inputs forwarded to the ``mdl`` helpers.
        timed_ytest: Reference labels from which the true variations are
            extracted.

    Returns:
        Tuple ``(timed_ypred, raw_proba, true_variations, pred_variations,
        true_crossings)``.
    """
    predicted = mdl.get_pred_timed(
        ANN, timed_Xtest, data.drop(['label'], axis=1))
    probabilities = mdl.get_prob_timed(
        ANN, timed_Xtest, data.drop(['label'], axis=1))

    # Variations: categorise the reference ones, then pair each predicted
    # variation with the closest reference variation of the same category.
    predicted_var = dat.get_var(predicted)
    reference_var = dat.get_category(dat.get_var(timed_ytest))
    matched_var = dat.get_closest_var_by_cat(
        reference_var, dat.get_category(predicted_var))

    # Reference crossings come from the true variations only.
    reference_crossings = dat.crossings_from_var(reference_var)

    return (predicted, probabilities, reference_var, matched_var,
            reference_crossings)
def get_corrected_prediction(timed_ypred, raw_proba, true_variations,
                             pred_variations, *, dt_corr=None,
                             min_var_duration=15):
    """Correct a timed prediction and derive its variations and crossings.

    Args:
        timed_ypred: Timed prediction to correct.
        raw_proba: Probabilities matching ``timed_ypred``.
        true_variations: Categorised reference variations used to pair the
            corrected variations.
        pred_variations: Variations of the uncorrected prediction.
        dt_corr: Value forwarded as the last argument of
            ``dat.get_corrected_pred``. When ``None`` (default) the
            module-level ``dt_corr_pred`` is used, which is what this
            function always did before the parameter existed.
        min_var_duration: Threshold forwarded to ``dat.corrected_var``.
            The original comment states it deletes variations faster than
            15 s; the default keeps that value.

    Returns:
        Tuple ``(timed_ycorr, corr_variations, corr_crossings)``.

    Raises:
        NameError: If ``dt_corr`` is omitted and no module-level
            ``dt_corr_pred`` exists.
    """
    if dt_corr is None:
        # NOTE(review): ``run_correction`` writes ``Dt_corr_pred`` (capital
        # D); confirm that the lowercase global read here is the intended
        # one, or pass ``dt_corr`` explicitly.
        dt_corr = dt_corr_pred

    timed_ycorr = dat.get_corrected_pred(
        pred_variations, timed_ypred, raw_proba, dt_corr)

    corr_variations = dat.get_var(timed_ycorr)
    # Filter out variations below the duration threshold before pairing.
    corr_variations = dat.corrected_var(corr_variations, min_var_duration)
    corr_variations = dat.get_closest_var_by_cat(
        true_variations, dat.get_category(corr_variations))

    corr_crossings = dat.crossings_from_var(corr_variations)
    return timed_ycorr, corr_variations, corr_crossings
def run_correction():
    """Recompute the corrected prediction and publish it through globals.

    Reads ``Dt_var_selection.get()`` (converted to ``int``) together with
    the globals ``pred_variations``, ``timed_ypred``, ``raw_proba``,
    ``true_variations``, ``true_crossings`` and ``data``. Writes the globals
    ``Dt_corr_pred``, ``timed_ycorr``, ``corr_variations`` and
    ``corr_crossings``.

    Unlike ``get_corrected_prediction``, this does not call
    ``dat.corrected_var``; it does append the position columns and pairs
    the resulting crossings with the reference ones.
    """
    global Dt_corr_pred, timed_ycorr, corr_variations, corr_crossings

    def _with_position(timed):
        # A fresh column list is built on every call, as in the original.
        return dat.append_data_to_timed(timed, data, ['x', 'y', 'z', 'rho'])

    Dt_corr_pred = int(Dt_var_selection.get())

    timed_ycorr = dat.get_corrected_pred(
        pred_variations, timed_ypred, raw_proba, Dt_corr_pred)

    corr_variations = dat.get_var(timed_ycorr)
    corr_variations = dat.get_closest_var_by_cat(
        true_variations, dat.get_category(corr_variations))

    # Position columns are attached only after the variations were paired.
    timed_ycorr = _with_position(timed_ycorr)
    corr_variations = _with_position(corr_variations)

    corr_crossings = dat.crossings_from_var(corr_variations)
    corr_crossings = dat.get_closest_cross(true_crossings, corr_crossings)
def pred_from_unseen(model, unseen_data, scale_data, dt_corr, dt_density):
    """Predict, correct and extract crossings for data without labels.

    Args:
        model: Model handed to ``mdl.get_pred_timed`` / ``mdl.get_prob_timed``.
        unseen_data: Input table; missing values are filled with
            ``scale_data.median()`` before prediction (assumes a pandas-like
            ``fillna`` -- confirm against callers).
        scale_data: Reference table passed to the ``mdl`` helpers and used
            for the fill values.
        dt_corr: Forwarded as last argument to ``dat.get_corrected_pred``.
        dt_density: Forwarded as last argument to ``dat.crossings_density``.

    Returns:
        Tuple ``(corr_pred, vcorr, final_crossings)``: the corrected
        prediction after ``dat.crossings_density``, the filtered categorised
        variations, and the output of ``dat.final_list``.
    """
    filled = unseen_data.fillna(scale_data.median())

    first_pred = mdl.get_pred_timed(model, filled, scale_data)
    class_proba = mdl.get_prob_timed(model, filled, scale_data)

    first_var = dat.get_category(dat.get_var(first_pred))
    corrected = dat.get_corrected_pred(
        first_var, first_pred, class_proba, dt_corr)

    categorised = dat.get_category(dat.get_var(corrected))
    # Per the original comment: deletes variations faster than 15s.
    vcorr = dat.corrected_var(categorised, 15)

    crossings = dat.crossings_from_var(vcorr)
    corrected = dat.crossings_density(corrected, crossings, dt_density)
    final_crossings = dat.final_list(corrected)

    return corrected, vcorr, final_crossings
def run_prediction():
    """Run the model on the test set and publish the results through globals.

    Reads the globals ``ANN``, ``timed_Xtest``, ``timed_ytest`` and ``data``.
    Writes the globals ``timed_ypred``, ``raw_proba``, ``true_variations``,
    ``pred_variations`` and ``true_crossings``. The ``x``/``y``/``z``/``rho``
    columns of ``data`` are appended to the predictions, the probabilities
    and both variation tables.
    """
    global timed_ypred, raw_proba
    global true_variations, pred_variations, true_crossings

    def _with_position(timed):
        # A fresh column list is built on every call, as in the original.
        return dat.append_data_to_timed(timed, data, ['x', 'y', 'z', 'rho'])

    timed_ypred = mdl.get_pred_timed(
        ANN, timed_Xtest, data.drop('label', axis=1))
    raw_proba = mdl.get_prob_timed(
        ANN, timed_Xtest, data.drop('label', axis=1))
    timed_ypred = _with_position(timed_ypred)
    raw_proba = _with_position(raw_proba)

    # Variations are computed on the prediction that already carries the
    # position columns, then paired by category with the reference ones.
    pred_variations = dat.get_var(timed_ypred)
    true_variations = dat.get_category(dat.get_var(timed_ytest))
    pred_variations = dat.get_closest_var_by_cat(
        true_variations, dat.get_category(pred_variations))
    true_variations = _with_position(true_variations)
    pred_variations = _with_position(pred_variations)

    true_crossings = dat.crossings_from_var(true_variations)