def test(net, x, ground_truth=None):
    """Run the UNet on ``x``, plot the first signal and summarize segment durations.

    No F1 score is computed here; see ``eval_unet`` for Se/PPV/F1.

    Args:
        net: (nn.Module) UNet module
        x: (Tensor) sized [#signals, 1 lead, signal_length]
        ground_truth: (Tensor, optional) sized [#signals, 4 segments, signal_length],
            the 4 segments being background, p, qrs and t. Only the first signal is
            used, and only for plotting.

    Returns:
        plot: value returned by ``predict_plotter`` for the first signal
        intervals: result of ``validation_duration_accuracy`` on the onsets/offsets
            of the first three predicted channels; see
            utils.val_utils.validation_duration_accuracy for its structure
    """
    net.eval()
    # expected input: (num_of_signals, 1, 500 * seconds)
    with torch.no_grad():
        logits = net(x)
    # expected output: (num_of_signals, 4, 500 * seconds)

    plot_args = [x[0][0], logits[0]]
    if ground_truth is not None:
        plot_args.append(ground_truth[0])
    plot = predict_plotter(*plot_args)

    # hard per-sample decision, laid out back as (signals, classes, samples)
    class_ids = logits.argmax(1)
    one_hot = F.one_hot(class_ids, num_classes=4).permute(0, 2, 1)
    # only the first three channels feed the onset/offset extraction
    onset_offset = onset_offset_generator(one_hot[:, :3, :])
    return plot, validation_duration_accuracy(onset_offset)
def _qrs_intervals(onset_offset, signal_length=4992):
    """Collect ``[onset, offset]`` sample pairs from channel 2 of ``onset_offset``.

    An onset is a sample equal to -1; its offset is the first later sample that is
    non-zero. Onsets are searched in the first ``signal_length`` samples (clipped to
    the array length). An onset with no offset before the end of the array is
    dropped instead of indexing past the end.

    Returns:
        list (one entry per signal) of lists of ``[onset, offset]`` pairs.
    """
    n_samples = onset_offset.shape[2]
    scan_end = min(signal_length, n_samples)
    intervals = []
    for sig in range(onset_offset.shape[0]):
        sig_intervals = []
        j = 0
        while j < scan_end:
            if onset_offset[sig, 2, j] == -1:
                onset = j
                j += 1
                while j < n_samples and onset_offset[sig, 2, j] == 0:
                    j += 1
                if j < n_samples:
                    sig_intervals.append([onset, j])
            j += 1
        intervals.append(sig_intervals)
    return intervals


def _turning_points_per_window(n_signals, qrs_intervals, turn_point):
    """Group each signal's turning points by the windows from ``enlarge_qrs_list``.

    For every window (presumably a widened QRS interval) the turning points ``p``
    with ``window[0] <= p <= window[1]`` are kept.

    Returns:
        list (one entry per signal) of lists (one per window) of turning points.
    """
    enlarge_qrs = enlarge_qrs_list(qrs_intervals)
    turning = []
    for index in range(n_signals):
        points = turn_point[index]
        turning.append([[p for p in points if window[0] <= p <= window[1]]
                        for window in enlarge_qrs[index]])
    return turning


def _window_durations(sig, tp):
    """Return ``(q, r, s)`` durations, in samples, for one QRS window.

    ``sig`` is the 1-D signal and ``tp`` the window's turning points (at least 3).
    The Q end / S start are the indices returned by ``find_index_closest_to_value``
    for the level of the first / last turning point.
    """
    if len(tp) >= 5:
        # q, r, s
        q_end = find_index_closest_to_value(sig[tp[1]:tp[2]], sig[tp[0]]) + tp[1]
        s_start = find_index_closest_to_value(sig[tp[2]:tp[3]], sig[tp[4]]) + tp[2]
        return q_end - tp[0], s_start - q_end, tp[4] - s_start
    if len(tp) == 4:
        if sig[tp[1]] > sig[tp[2]]:
            # r, s
            s_start = find_index_closest_to_value(sig[tp[1]:tp[2]], sig[tp[3]]) + tp[1]
            return 0, s_start - tp[0], tp[3] - s_start
        # q, r
        q_end = find_index_closest_to_value(sig[tp[1]:tp[2]], sig[tp[0]]) + tp[1]
        return q_end - tp[0], tp[3] - q_end, 0
    # exactly 3 turning points: only r (rising) or only q (falling)
    width = tp[2] - tp[0]
    if sig[tp[1]] > sig[tp[0]]:
        return 0, width, 0
    return width, 0, 0


def _wave_durations(ekg_np, turning):
    """Compute per-signal Q/R/S durations from grouped turning points.

    For each signal only windows whose turning-point count equals the most common
    count (the mode) for that signal are measured.
    """
    pred = []
    for i, windows in enumerate(turning):
        durations = {"q_duration": [], "r_duration": [], "s_duration": []}
        pred.append(durations)
        if not windows:
            continue  # no QRS found: np.bincount([]) has no argmax
        mode = np.argmax(np.bincount([len(w) for w in windows]))
        if mode < 3:
            continue  # too few turning points to delimit any wave
        sig = ekg_np[i, 0]
        for tp in windows:
            if len(tp) != mode:
                continue
            q, r, s = _window_durations(sig, tp)
            durations["q_duration"].append(q)
            durations["r_duration"].append(r)
            durations["s_duration"].append(s)
    return pred


def _qrs_durations(ekg_np, turn_point, final_preds, signal_length=4992):
    """Shared pipeline: predicted masks -> QRS intervals -> turning points -> durations."""
    onset_offset = onset_offset_generator(final_preds)
    qrs_intervals = _qrs_intervals(onset_offset, signal_length)
    turning = _turning_points_per_window(ekg_np.shape[0], qrs_intervals, turn_point)
    return _wave_durations(ekg_np, turning)


def qrs_seperation(ekg_sig, final_preds, signal_length=4992):
    """Measure Q, R and S wave durations inside every predicted QRS segment.

    Args:
        ekg_sig: (Tensor) signals indexed as ``[signal, lead, sample]``; lead 0 is used.
            Turning points come from ``get_signals_turning_point_by_rdp(ekg_sig, load=True)``.
        final_preds: (Tensor) masks passed to ``onset_offset_generator``; channel 2 of
            its output is treated as the QRS channel.
        signal_length: (int) onsets are searched only in the first ``signal_length``
            samples (clipped to the prediction length).

    Returns:
        list of dict, one per signal, with keys ``"q_duration"``, ``"r_duration"`` and
        ``"s_duration"``, each a list of durations in samples. Signals without a usable
        QRS window get empty lists.
    """
    turn_point = get_signals_turning_point_by_rdp(ekg_sig, load=True)
    ekg_np = ekg_sig.cpu().numpy()
    return _qrs_durations(ekg_np, turn_point, final_preds, signal_length)


def test_retinanet_by_qrs(net):
    """Evaluate RetinaNet Q/R/S durations on the ANE and CAL signals (RDP turning points).

    Turning points are computed on the raw signals. The signals are then normalized
    and segmented by ``test_retinanet`` in batches of 128. Records are grouped 5 at a
    time into 17 reference groups. For each group, the mean predicted duration times
    2 (presumably samples -> ms at 500 Hz; TODO confirm) is compared to the reference.
    ``removeworst(mean_diff, 4)`` is applied, and the mean and sample standard
    deviation of the differences are printed.

    Args:
        net: (nn.Module) Retinanet module

    Returns:
        (mean_diff_mean, mean_diff_std): the two printed arrays, in q, r, s order.
    """
    ekg_sig = load_ANE_CAL(denoise=False, pre=False, nor=False)
    # turning points are taken from the signals before normalization
    turn_point = get_signals_turning_point_by_rdp(ekg_sig, load=True)
    ekg_sig = normalize(ekg_sig)

    batch = 128
    final_preds = []
    # step through the signals so no empty trailing batch is ever produced
    for start in range(0, ekg_sig.size(0), batch):
        _, _, pred_signals = test_retinanet(net, ekg_sig[start:start + batch, :, :],
                                            4992, visual=False)
        final_preds.append(pred_signals)
    final_preds = torch.cat(final_preds, dim=0)

    pred = _qrs_durations(ekg_sig.cpu().numpy(), turn_point, final_preds, 4992)

    ane = {"q_duration": 12, "r_duration": 52, "s_duration": 30}
    rs_50 = {"q_duration": 0, "r_duration": 50, "s_duration": 50}
    r_56 = {"q_duration": 0, "r_duration": 56, "s_duration": 0}
    q_56 = {"q_duration": 56, "r_duration": 0, "s_duration": 0}
    rs_18 = {"q_duration": 0, "r_duration": 18, "s_duration": 18}
    # 3 ANE groups followed by 14 CAL groups
    standard_qrs = ([ane] * 3 + [rs_50] * 5 + [r_56] * 3 + [q_56] * 3
                    + [rs_18] + [rs_50] * 2)

    records_per_group = 5
    keys = ("q_duration", "r_duration", "s_duration")
    mean_diff = np.zeros((len(keys), len(standard_qrs)))
    for g, standard in enumerate(standard_qrs):
        group = pred[g * records_per_group:(g + 1) * records_per_group]
        for k, key in enumerate(keys):
            record_means = [np.mean(record[key]) for record in group]
            mean_diff[k][g] = np.mean(record_means) * 2 - standard[key]

    mean_diff = removeworst(mean_diff, 4)
    mean_diff_mean = np.mean(mean_diff, axis=1)
    mean_diff_std = np.std(mean_diff, axis=1, ddof=1)
    print(mean_diff_mean)
    print(mean_diff_std)
    return mean_diff_mean, mean_diff_std
def test_retinanet(net, x, input_length, ground_truth=None, visual=False):
    """Run RetinaNet on preprocessed signals and decode the result into masks.

    Args:
        net: (nn.Module) RetinaNet model
        x: (Tensor) sized [#signals, 1 lead, values]
        input_length: (int) input length, must be divisible by 64
        ground_truth: (tuple of Tensor, optional) ``(loc_targets, cls_targets)``
            anchor targets for the same batch. When given they are decoded into
            ground-truth masks, which are drawn next to the predictions if
            ``visual`` is True.
        visual: (bool) draw one plot per signal with ``predict_plotter``.

    Returns:
        plot: plot of the last signal when ``visual`` is True, otherwise None.
        intervals: result of ``validation_duration_accuracy`` on the predicted
            onsets/offsets of channels 1..3; see
            utils.val_utils.validation_duration_accuracy for its structure.
        pred_signals: (Tensor) predicted masks, one ``box_to_sig_generator``
            output (or a [1, 4, input_length] zero mask) per signal, concatenated.
    """
    net.eval()
    # inference only: do not build an autograd graph
    with torch.no_grad():
        loc_preds, cls_preds = net(x)
    loc_preds = loc_preds.type(torch.FloatTensor)
    cls_preds = cls_preds.type(torch.FloatTensor)

    has_gt = ground_truth is not None
    if has_gt:
        loc_targets, cls_targets = ground_truth
        loc_targets = loc_targets.detach().type(torch.FloatTensor)
        cls_targets = cls_targets.detach().type(torch.LongTensor)

    batch_size = x.size(0)
    encoder = DataEncoder()
    pred_sigs = []
    gt_sigs = []
    for i in range(batch_size):
        boxes, labels, _, is_found = encoder.decode(
            loc_preds[i], cls_preds[i], input_length, CLS_THRESH=0.425, NMS_THRESH=0.5)
        if is_found:
            boxes = boxes.ceil()
            xmin = boxes[:, 0].clamp(min=1)
            xmax = boxes[:, 1].clamp(max=input_length - 1)
            pred_sig = box_to_sig_generator(xmin, xmax, labels, input_length, background=False)
        else:
            pred_sig = torch.zeros(1, 4, input_length)
        pred_sigs.append(pred_sig)

        if has_gt:
            gt_boxes, gt_labels, _, _ = encoder.decode(
                loc_targets[i], one_hot_embedding(cls_targets[i], 4), input_length)
            gt_sigs.append(box_to_sig_generator(gt_boxes[:, 0], gt_boxes[:, 1], gt_labels,
                                                input_length, background=False))

    pred_signals = torch.cat(pred_sigs, 0)
    gt_signals = torch.cat(gt_sigs, 0) if has_gt else None

    plot = None
    if visual:
        for i in range(batch_size):
            if gt_signals is not None:
                # plot the decoded ground-truth mask, not the raw (loc, cls) targets
                plot = predict_plotter(x[i][0], pred_signals[i], gt_signals[i], name=str(i))
            else:
                plot = predict_plotter(x[i][0], pred_signals[i], name=str(i))

    pred_onset_offset = onset_offset_generator(pred_signals)
    intervals = validation_duration_accuracy(pred_onset_offset[:, 1:, :])
    return plot, intervals, pred_signals
def eval_retinanet(model, dataloader, device="cuda"):
    """Evaluation function usable during RetinaNet training.

    Args:
        model: (nn.Module) RetinaNet module variable
        dataloader: (DataLoader) validation dataloader yielding
            ``(inputs, loc_targets, cls_targets, gt_boxes, gt_labels, gt_peaks)``;
            only ``inputs``, ``gt_boxes`` and ``gt_labels`` are used.
        device: device the inputs are moved to before the forward pass
            (default ``"cuda"``).

    Returns:
        Se: (float) TP / (TP + FN), 0.0 when TP + FN == 0
        PPV: (float) TP / (TP + FP), 0.0 when TP + FP == 0
        F1: (float) 2 * Se * PPV / (Se + PPV), 0.0 when Se + PPV == 0

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    input_length = 3968
    model.eval()
    # the decoder only processes data one signal at a time; one instance serves all batches
    encoder = DataEncoder()

    pred_sigs = []
    gt_sigs = []
    first_sig = None  # first signal of the first batch, used for the plot
    with torch.no_grad():
        for inputs, _, _, gt_boxes, gt_labels, _ in dataloader:
            batch_size = inputs.size(0)
            inputs = inputs.to(device).unsqueeze(1)
            if first_sig is None:
                first_sig = inputs[0][0]

            loc_preds, cls_preds = model(inputs)
            # no squeeze(): it would drop the batch dimension when batch_size == 1
            loc_preds = loc_preds.type(torch.FloatTensor)
            cls_preds = cls_preds.type(torch.FloatTensor)

            for i in range(batch_size):
                boxes, labels, _, is_found = encoder.decode(
                    loc_preds[i], cls_preds[i], input_length)

                # ground truth is rebuilt from the dataset boxes, not decoded from anchors
                gt_boxes_tensor = torch.tensor(gt_boxes[i])
                gt_labels_tensor = torch.tensor(gt_labels[i])
                gt_sigs.append(box_to_sig_generator(
                    gt_boxes_tensor[:, 0].clamp(min=1),
                    gt_boxes_tensor[:, 1].clamp(max=input_length - 1),
                    gt_labels_tensor, input_length, background=False))

                if is_found:
                    boxes = boxes.ceil()
                    # there is no background anchor on predict labels
                    pred_sigs.append(box_to_sig_generator(
                        boxes[:, 0].clamp(min=1),
                        boxes[:, 1].clamp(max=input_length - 1),
                        labels, input_length, background=False))
                else:
                    pred_sigs.append(torch.zeros(1, 4, input_length))

    if first_sig is None:
        raise ValueError("eval_retinanet: dataloader yielded no batches")

    pred_signals = torch.cat(pred_sigs, 0)
    gt_signals = torch.cat(gt_sigs, 0)
    predict_plotter(first_sig, pred_signals[0], gt_signals[0])

    pred_onset_offset = onset_offset_generator(pred_signals)
    gt_onset_offset = onset_offset_generator(gt_signals)
    TP, FP, FN = validation_accuracy(pred_onset_offset, gt_onset_offset)

    # guard against empty predictions / ground truth instead of dividing by zero
    Se = TP / (TP + FN) if TP + FN else 0.0
    PPV = TP / (TP + FP) if TP + FP else 0.0
    F1 = 2 * Se * PPV / (Se + PPV) if Se + PPV else 0.0
    print("Se: {} PPV: {} F1 score: {}".format(Se, PPV, F1))
    wandb.log({"Se": Se, "PPV": PPV, "F1": F1})
    return Se, PPV, F1
def eval_unet(net, loader, device):
    """Evaluation function usable during UNet training.

    Args:
        net: (nn.Module) UNet module variable
        loader: (DataLoader) validation dataloader; ``batch[0]`` is the input and
            ``batch[1]`` the ground-truth masks
        device: (str) using GPU or CPU

    Returns:
        average loss: (float) mean per-batch BCE-with-logits loss
        pointwise acc: (float) fraction of matching (channel, sample) entries
            between the one-hot prediction and the ground truth
        Se: (float) TP / (TP + FN), 0.0 when TP + FN == 0
        PPV: (float) TP / (TP + FP), 0.0 when TP + FP == 0
        F1: (float) 2 * Se * PPV / (Se + PPV), 0.0 when Se + PPV == 0
        ret: result of ``validation_duration_accuracy`` for the LAST batch only
            (None if the loader is empty)

    Raises:
        ZeroDivisionError: if ``loader`` is empty.
    """
    net.eval()
    n_val = len(loader)
    tot = 0
    correct = 0
    total = 0
    TP = FP = FN = 0
    last_pred_onset_offset = None
    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False, ncols=100) as pbar:
        for batch in loader:
            x = batch[0].to(device, dtype=torch.float32)
            ground_truth = batch[1].to(device, dtype=torch.float32)

            with torch.no_grad():
                pred = net(x)
            tot += F.binary_cross_entropy_with_logits(pred, ground_truth).item()

            # hard one-hot decision per sample: (batch_size, channels, data)
            pred_ans = F.one_hot(pred.argmax(1), num_classes=4).permute(0, 2, 1)
            correct += pred_ans.eq(ground_truth).sum().item()
            total += ground_truth.numel()

            # only use first 3 channels because first three channels will produce all onsets/offsets
            pred_onset_offset = onset_offset_generator(pred_ans[:, :3, :])
            gt_onset_offset = onset_offset_generator(ground_truth[:, :3, :])
            tp, fp, fn = validation_accuracy(pred_onset_offset, gt_onset_offset)
            TP += tp
            FP += fp
            FN += fn
            last_pred_onset_offset = pred_onset_offset

            pbar.update()

    # duration statistics are reported for the last batch only; compute them once
    ret = (validation_duration_accuracy(last_pred_onset_offset)
           if last_pred_onset_offset is not None else None)

    # guard against batches without any detected/annotated onsets-offsets
    Se = TP / (TP + FN) if TP + FN else 0.0
    PPV = TP / (TP + FP) if TP + FP else 0.0
    F1 = 2 * Se * PPV / (Se + PPV) if Se + PPV else 0.0
    return tot / n_val, correct / total, Se, PPV, F1, ret