def _predict_helper(main_args, device):
    """Unpack a packed argument tuple and predict a single timepoint.

    Args:
        main_args: sequence of exactly three items,
            ``(model_name, t_predict, chunk_size)``.
        device: device identifier forwarded to both ``fetch_model`` and
            ``single_tp_nonblocking_predict``.
    """
    # Imports are resolved at call time, not at module import time.
    from division_detection.predict import single_tp_nonblocking_predict
    from division_detection.model import fetch_model

    name, timepoint, chunks = main_args
    # The model spec returned alongside the model is not needed here.
    net, _spec = fetch_model(name, device=device)
    h5_name = "{}.h5".format(name)
    single_tp_nonblocking_predict(
        net,
        h5_name,
        timepoint,
        device,
        in_mem=False,
        chunk_size=chunks,
    )
def make_predictions_at_t(model_name,
                          t_predict,
                          device='/gpu:0',
                          in_mem=True,
                          chunk_size=(200, 150, 150)):
    """Load a model and run prediction for one timepoint.

    Args:
        model_name: name passed to ``fetch_model``; ``<model_name>.h5`` is
            the name handed to the predictor.
        t_predict: the timepoint to predict.
        device: device string used for the TensorFlow device scope and
            forwarded to the loader and predictor.
        in_mem: forwarded unchanged to ``single_tp_nonblocking_predict``.
        chunk_size: forwarded unchanged to ``single_tp_nonblocking_predict``.
    """
    with tf.device(device):
        print("Loading model")
        net, _spec = fetch_model(model_name, device=device)

        print("beginning prediction")
        h5_name = '{}.h5'.format(model_name)
        single_tp_nonblocking_predict(
            net,
            h5_name,
            t_predict,
            device=device,
            in_mem=in_mem,
            chunk_size=chunk_size,
        )
def make_predictions_by_t(model_name,
                          device='/gpu:0',
                          in_mem=True,
                          chunk_size=(200, 150, 150)):
    """Load a model once and run prediction over a range of timepoints.

    The number of volumes is taken as the number of entries in
    ``VOL_DIR_H5``; timepoints ``3`` through ``num_vols - 5`` (inclusive)
    are predicted one after another with the same model.

    Args:
        model_name: name passed to ``fetch_model``; ``<model_name>.h5`` is
            the name handed to the predictor.
        device: device string used for the TensorFlow device scope and
            forwarded to the loader and predictor.
        in_mem: forwarded unchanged to ``single_tp_nonblocking_predict``.
        chunk_size: forwarded unchanged to ``single_tp_nonblocking_predict``.
    """
    from division_detection.vol_preprocessing import VOL_DIR_H5

    # One directory entry per timepoint volume.
    n_volumes = len(os.listdir(VOL_DIR_H5))

    with tf.device(device):
        print("Loading model")
        net, _spec = fetch_model(model_name, device=device)

        h5_name = '{}.h5'.format(model_name)
        # NOTE(review): the first 3 and last 4 timepoints are skipped,
        # presumably because they lack temporal context -- confirm.
        first_t, stop_t = 3, n_volumes - 4
        for timepoint in range(first_t, stop_t):
            print("beginning prediction for {}".format(timepoint))
            single_tp_nonblocking_predict(
                net,
                h5_name,
                timepoint,
                device=device,
                in_mem=in_mem,
                chunk_size=chunk_size,
            )
def _slurm_predict_helper_general(timepoint, in_dir, model_name, chunk_size):
    """Predict one timepoint from ``in_dir`` on the first GPU.

    Args:
        timepoint: the timepoint to predict.
        in_dir: input directory forwarded to the general predictor.
        model_name: name passed to ``fetch_model`` and, unchanged, to the
            predictor.
        chunk_size: forwarded unchanged to the predictor.
    """
    gpu = '/gpu:0'
    net, _spec = fetch_model(model_name, device=gpu)
    single_tp_nonblocking_predict_general(
        net,
        model_name,
        in_dir,
        timepoint,
        gpu,
        chunk_size=chunk_size,
    )
def _slurm_predict_helper(timepoint, model_name, chunk_size):
    """Predict one timepoint on the first GPU with ``in_mem=False``.

    Args:
        timepoint: the timepoint to predict.
        model_name: name passed to ``fetch_model`` and, unchanged, to the
            predictor.
        chunk_size: forwarded unchanged to the predictor.
    """
    gpu = '/gpu:0'
    net, _spec = fetch_model(model_name, device=gpu)
    # NOTE(review): the bare model name is passed here, whereas the other
    # non-general helpers in this file pass "<model_name>.h5" -- confirm
    # which form single_tp_nonblocking_predict expects.
    single_tp_nonblocking_predict(
        net,
        model_name,
        timepoint,
        gpu,
        in_mem=False,
        chunk_size=chunk_size,
    )
def _local_predict_helper_general(timepoints, in_dir, model_name, chunk_size,
                                  device):
    """Load the model once, then predict each given timepoint from ``in_dir``.

    Args:
        timepoints: iterable of timepoints, processed in iteration order.
        in_dir: input directory forwarded to the general predictor.
        model_name: name passed to ``fetch_model`` and, unchanged, to the
            predictor.
        chunk_size: forwarded unchanged to the predictor.
        device: device identifier forwarded to the loader and predictor.
    """
    net, _spec = fetch_model(model_name, device=device)
    for timepoint in timepoints:
        single_tp_nonblocking_predict_general(
            net,
            model_name,
            in_dir,
            timepoint,
            device,
            chunk_size=chunk_size,
        )
def _predict_local_helper(timepoints, model_name, chunk_size, device):
    """Load the model once, then predict each given timepoint.

    Every prediction is run with ``in_mem=False``.

    Args:
        timepoints: iterable of timepoints, processed in iteration order.
        model_name: name passed to ``fetch_model``; ``<model_name>.h5`` is
            the name handed to the predictor.
        chunk_size: forwarded unchanged to the predictor.
        device: device identifier forwarded to the loader and predictor.
    """
    net, _spec = fetch_model(model_name, device=device)
    # The file name does not depend on the timepoint, so build it once.
    h5_name = "{}.h5".format(model_name)
    for timepoint in timepoints:
        single_tp_nonblocking_predict(
            net,
            h5_name,
            timepoint,
            device,
            in_mem=False,
            chunk_size=chunk_size,
        )
def pipeline_analyze(model_name, partials=True, test=True):
    """Pipelined analysis method.

    Gathers raw prediction scores and matching labels in one of two ways,
    then prints accuracy and class counts, computes precision-recall and ROC
    curves and confusion matrices at several thresholds, and prints the
    row-normalized confusion matrices.

    Args:
        model_name: name passed to ``fetch_model`` (partials mode) or used to
            locate ``<model_name>.h5`` in the prediction outbox (volume
            mode).
        partials: if true, run the model on the cutouts stored in the split
            partials file. Otherwise compare previously stored predictions
            against the ground-truth validation volumes.
        test: only consulted when ``partials`` is true; selects the
            ``'test'`` partials file when true and the ``'train'`` one
            otherwise.

    Returns:
        dict with keys ``'pr_curve'`` (precision / recall / thresholds),
        ``'roc_curve'`` (false_pos_rate / true_pos_rate / thresholds) and
        ``'confusion_matrices'`` (decision threshold -> confusion matrix).

    Raises:
        RuntimeError: in volume mode, when the predictions file does not
            exist or no validation timepoint has any nonzero prediction.
    """
    if partials:
        from division_detection.vol_preprocessing import SPLIT_PARTIALS_PATH_TEMPLATE, REC_FIELD_SHAPE

        # Only this branch runs inference; volume mode reads stored
        # predictions, so the model is not loaded there.
        model, _ = fetch_model(model_name)

        split_name = 'test' if test else 'train'
        partials_path = SPLIT_PARTIALS_PATH_TEMPLATE.format(split_name)

        with h5py.File(partials_path, 'r') as partials_file:
            shape_group = partials_file[str(tuple(REC_FIELD_SHAPE))]
            # [n_samples] + REC_FIELD_SHAPE
            partial_cutouts = shape_group['cutouts'][:]
            # [n_samples,]
            labels = shape_group['labels'][:]

        raw_predictions = model.predict(partial_cutouts).squeeze()
    # evaluate PR on fully annotated validation volumes
    else:
        annotations = fetch_validation_annotations()
        # First annotation column holds the timepoint index.
        valid_tps = np.unique(annotations[:, 0]).astype(np.int32)

        prediction_path = '/nrs/turaga/bergera/division_detection/prediction_outbox/{}.h5'.format(
            model_name)
        gt_path = os.path.expanduser(
            '~/data/div_detect/full_res_gt_vols/validation.h5')

        if not os.path.exists(prediction_path):
            raise RuntimeError("Predictions file missing")

        # flattened over all volumes
        raw_predictions = []
        labels = []

        # Open read-only explicitly: older h5py releases default to append
        # mode, which may create or lock files that are only being read.
        with h5py.File(prediction_path, 'r') as predictions_file, \
                h5py.File(gt_path, 'r') as gt_file:
            predictions = predictions_file['predictions']
            for timept in valid_tps:
                tp_predict = predictions[timept]
                gt_vol = gt_file[str(timept)][:]
                # An all-zero volume is treated as "not predicted yet".
                if tp_predict.sum() > 0:
                    raw_predictions.append(tp_predict.ravel())
                    labels.append(gt_vol.ravel())
                else:
                    warn(
                        "No predictions found for timepoint {}".format(timept))

        if not raw_predictions:
            raise RuntimeError("No validation predictions found")

        raw_predictions = np.concatenate(raw_predictions)
        labels = np.concatenate(labels)

    class_predictions = (raw_predictions > 0.5).astype(int)
    correct_predictions = class_predictions == labels
    test_accuracy = np.sum(correct_predictions) / float(
        len(correct_predictions))

    n_pos_samples = np.sum(labels)
    n_neg_samples = np.sum(np.logical_not(labels))

    print("Achieved {} test set accuracy".format(test_accuracy))
    print("Test set contains {} positive examples and {} negative examples".
          format(n_pos_samples, n_neg_samples))

    # Flatten once; every metric below works on 1-D arrays.
    flat_labels = labels.ravel()
    flat_scores = raw_predictions.ravel()

    print("Computing precision recall curve")
    precision, recall, pr_thresholds = precision_recall_curve(
        flat_labels, flat_scores, pos_label=1)
    precision_recall_dict = {
        'precision': precision,
        'recall': recall,
        'thresholds': pr_thresholds
    }

    print("Computing ROC curve")
    false_pos_rate, true_pos_rate, roc_thresholds = roc_curve(
        flat_labels, flat_scores, pos_label=1)
    roc_dict = {
        'false_pos_rate': false_pos_rate,
        'true_pos_rate': true_pos_rate,
        'thresholds': roc_thresholds
    }

    print('Computing confusion matrix')
    decision_thresholds = [0.1, 0.3, 0.5, 0.9, 0.95]
    confusion_matrices = {
        thresh: confusion_matrix(flat_labels, flat_scores > thresh)
        for thresh in decision_thresholds
    }

    analysis_results = {
        'pr_curve': precision_recall_dict,
        'roc_curve': roc_dict,
        'confusion_matrices': confusion_matrices
    }

    for thresh, cm in iteritems(confusion_matrices):
        # Normalize each row (true class) so its entries sum to 1.
        norm_cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print(
            "Normalized confusion matrix at decision threshold of {}:".format(
                thresh))
        print(norm_cm)

    return analysis_results