Example #1
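These examples share a common set of imports. The project-local helpers (predict_plotter, onset_offset_generator, validation_duration_accuracy, validation_accuracy, box_to_sig_generator, one_hot_embedding, DataEncoder, load_ANE_CAL, normalize, get_signals_turning_point_by_rdp, enlarge_qrs_list, find_index_closest_to_value, and removeworst) come from the source repository's own modules such as utils.val_utils and are not reproduced here. A minimal import header, assuming those helpers are importable:

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm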
def test(net, x, ground_truth=None):
    """
    test the UNet model by computing F1 score

    Args:
        net: (nn.Module) UNet module
        x: (Tensor) with sized [#signals, 1 lead, signal_length]
        ground_truth: (Tensor) with sized [#signals, 4 segments, signal_length], 4 segments are background, p, qrs, and t

    Returns:
        plot: (plt object)
        intervals: (dict) with complex structure, see utils.val_utils.validation_duration_accuracy for more information
    """
    net.eval()
    # input size should be (num_of_signals, 1, 500 * seconds)
    with torch.no_grad():
        output = net(x)
    # output size should be (num_of_signals, 4, 500 * seconds)
    if ground_truth is not None:
        plot = predict_plotter(x[0][0], output[0], ground_truth[0])
    else:
        plot = predict_plotter(x[0][0], output[0])

    pred_ans = F.one_hot(output.argmax(1), num_classes=4).permute(0, 2, 1)

    output_onset_offset = onset_offset_generator(pred_ans[:, :3, :])
    intervals = validation_duration_accuracy(output_onset_offset)
    return plot, intervals
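A minimal usage sketch for test(), assuming a trained UNet with 1 input and 4 output channels and a checkpoint path (the constructor and file name are assumptions, not part of the original example):

net = UNet(in_channels=1, out_channels=4)            # hypothetical constructor
net.load_state_dict(torch.load("unet_weights.pt"))   # hypothetical checkpoint
x = torch.randn(8, 1, 500 * 10)   # 8 signals, 1 lead, 10 s at 500 Hz
plot, intervals = test(net, x)    # ground_truth omitted: plot shows predictions only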
def qrs_seperation(ekg_sig, final_preds):
    turn_point = get_signals_turning_point_by_rdp(ekg_sig, load=True)
    ekg_sig = ekg_sig.cpu().numpy()
    """qrs segmentation"""
    onset_offset = onset_offset_generator(final_preds)
    qrs_interval = []
    for i in range(onset_offset.shape[0]):
        qrs_interval.append([])
        j = 0
        while j < 4992:
            if onset_offset[i, 2, j] == -1:
                # -1 marks a QRS onset; scan forward to the matching offset
                qrs_interval[i].append([j])
                j += 1
                while j < 4992 and onset_offset[i, 2, j] == 0:
                    j += 1
                qrs_interval[i][-1].append(j)
            j += 1
    enlarge_qrs = enlarge_qrs_list(qrs_interval)

    turning = []
    for index in range(ekg_sig.shape[0]):
        turning.append([])
        for j in range(len(enlarge_qrs[index])):
            # keep only the turning points inside this (enlarged) QRS window
            filtered_peaks = [
                p for p in turn_point[index]
                if enlarge_qrs[index][j][0] <= p <= enlarge_qrs[index][j][1]
            ]
            turning[index].append(filtered_peaks)

    pred = []
    for i in range(len(turning)):
        pred.append({"q_duration": [], "r_duration": [], "s_duration": []})
        # the most common number of turning points across this signal's beats
        mode = np.argmax(np.bincount([len(t) for t in turning[i]]))
        for j in range(len(turning[i])):
            if len(turning[i][j]) != mode:
                continue
            if mode >= 5:
                # q,r,s
                # find q duration
                q_end = find_index_closest_to_value(
                    ekg_sig[i, 0, turning[i][j][1]:turning[i][j][2]],
                    ekg_sig[i, 0, turning[i][j][0]])
                q_end = q_end + turning[i][j][1]
                q_duration = q_end - turning[i][j][0]
                pred[i]["q_duration"].append(q_duration)
                # find s duration
                s_start = find_index_closest_to_value(
                    ekg_sig[i, 0, turning[i][j][2]:turning[i][j][3]],
                    ekg_sig[i, 0, turning[i][j][4]])
                s_start = s_start + turning[i][j][2]
                s_duration = turning[i][j][4] - s_start
                pred[i]["s_duration"].append(s_duration)
                # find r duration
                r_start = q_end
                r_end = s_start
                r_duration = r_end - r_start
                pred[i]["r_duration"].append(r_duration)
            elif mode == 4:
                # q,r or r,s
                if ekg_sig[i, 0, turning[i][j][1]] > ekg_sig[i, 0,
                                                             turning[i][j][2]]:
                    pred[i]["q_duration"].append(0)
                    # r, s
                    # find s duration
                    s_start = find_index_closest_to_value(
                        ekg_sig[i, 0, turning[i][j][1]:turning[i][j][2]],
                        ekg_sig[i, 0, turning[i][j][3]])
                    s_start = s_start + turning[i][j][1]
                    s_duration = turning[i][j][3] - s_start
                    pred[i]["s_duration"].append(s_duration)
                    # find r duration
                    r_end = s_start
                    r_duration = r_end - turning[i][j][0]
                    pred[i]["r_duration"].append(r_duration)
                else:
                    # q, r
                    pred[i]["s_duration"].append(0)
                    # find q duration
                    q_end = find_index_closest_to_value(
                        ekg_sig[i, 0, turning[i][j][1]:turning[i][j][2]],
                        ekg_sig[i, 0, turning[i][j][0]])
                    q_end = q_end + turning[i][j][1]
                    q_duration = q_end - turning[i][j][0]
                    pred[i]["q_duration"].append(q_duration)
                    # find r duration
                    r_start = q_end
                    r_duration = turning[i][j][3] - r_start
                    pred[i]["r_duration"].append(r_duration)
            elif mode <= 3:
                # only q or r
                if ekg_sig[i, 0, turning[i][j][1]] > ekg_sig[i, 0,
                                                             turning[i][j][0]]:
                    # r
                    pred[i]["q_duration"].append(0)
                    pred[i]["s_duration"].append(0)
                    r_duration = turning[i][j][2] - turning[i][j][0]
                    pred[i]["r_duration"].append(r_duration)
                else:
                    # q
                    pred[i]["r_duration"].append(0)
                    pred[i]["s_duration"].append(0)
                    q_duration = turning[i][j][2] - turning[i][j][0]
                    pred[i]["q_duration"].append(q_duration)
    return pred
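A short usage sketch for the function above, assuming ekg_sig is a tensor sized [#signals, 1, 4992] and final_preds is the one-hot segmentation produced by test_retinanet in Example #4:

durations = qrs_seperation(ekg_sig, final_preds)
print(durations[0]["q_duration"])   # per-beat Q durations (in samples) for signal 0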
Example #3
def test_retinanet_by_qrs(net):
    """
    testing the CAL and ANE dataset q, r, s duration using rdp algorithm.

    Args:
        net: (nn.Module) Retinanet module
    """
    ekg_sig = load_ANE_CAL(denoise=False, pre=False, nor=False)
    turn_point = get_signals_turning_point_by_rdp(ekg_sig, load=True)

    final_preds = []
    ekg_sig = normalize(ekg_sig)
    # process in batches of 128; ceil-divide so a partial final batch is
    # included without ever producing an empty slice
    for i in range((ekg_sig.size(0) + 127) // 128):
        _, _, pred_signals = test_retinanet(net,
                                            ekg_sig[i * 128:(i + 1) * 128, :, :],
                                            4992,
                                            visual=False)
        final_preds.append(pred_signals)
    final_preds = torch.cat(final_preds, dim=0)
    ekg_sig = ekg_sig.cpu().numpy()

    onset_offset = onset_offset_generator(final_preds)
    qrs_interval = []
    for i in range(onset_offset.shape[0]):
        qrs_interval.append([])
        j = 0
        while j < 4992:
            if onset_offset[i, 2, j] == -1:
                # -1 marks a QRS onset; scan forward to the matching offset
                qrs_interval[i].append([j])
                j += 1
                while j < 4992 and onset_offset[i, 2, j] == 0:
                    j += 1
                qrs_interval[i][-1].append(j)
            j += 1

    enlarge_qrs = enlarge_qrs_list(qrs_interval)

    turning = []
    for index in range(ekg_sig.shape[0]):
        turning.append([])
        for j in range(len(enlarge_qrs[index])):
            # keep only the turning points inside this (enlarged) QRS window
            filtered_peaks = [
                p for p in turn_point[index]
                if enlarge_qrs[index][j][0] <= p <= enlarge_qrs[index][j][1]
            ]
            turning[index].append(filtered_peaks)

    pred = []
    for i in range(len(turning)):
        pred.append({"q_duration": [], "r_duration": [], "s_duration": []})
        # the most common number of turning points across this signal's beats
        mode = np.argmax(np.bincount([len(t) for t in turning[i]]))
        for j in range(len(turning[i])):
            if len(turning[i][j]) != mode:
                continue
            if mode >= 5:
                # q,r,s
                # find q duration
                q_end = find_index_closest_to_value(
                    ekg_sig[i, 0, turning[i][j][1]:turning[i][j][2]],
                    ekg_sig[i, 0, turning[i][j][0]])
                q_end = q_end + turning[i][j][1]
                q_duration = q_end - turning[i][j][0]
                pred[i]["q_duration"].append(q_duration)
                # find s duration
                s_start = find_index_closest_to_value(
                    ekg_sig[i, 0, turning[i][j][2]:turning[i][j][3]],
                    ekg_sig[i, 0, turning[i][j][4]])
                s_start = s_start + turning[i][j][2]
                s_duration = turning[i][j][4] - s_start
                pred[i]["s_duration"].append(s_duration)
                # find r duration
                r_start = q_end
                r_end = s_start
                r_duration = r_end - r_start
                pred[i]["r_duration"].append(r_duration)
            elif mode == 4:
                # q,r or r,s
                if ekg_sig[i, 0, turning[i][j][1]] > ekg_sig[i, 0,
                                                             turning[i][j][2]]:
                    pred[i]["q_duration"].append(0)
                    # r, s
                    # find s duration
                    s_start = find_index_closest_to_value(
                        ekg_sig[i, 0, turning[i][j][1]:turning[i][j][2]],
                        ekg_sig[i, 0, turning[i][j][3]])
                    s_start = s_start + turning[i][j][1]
                    s_duration = turning[i][j][3] - s_start
                    pred[i]["s_duration"].append(s_duration)
                    # find r duration
                    r_end = s_start
                    r_duration = r_end - turning[i][j][0]
                    pred[i]["r_duration"].append(r_duration)
                else:
                    # q, r
                    pred[i]["s_duration"].append(0)
                    # find q duration
                    q_end = find_index_closest_to_value(
                        ekg_sig[i, 0, turning[i][j][1]:turning[i][j][2]],
                        ekg_sig[i, 0, turning[i][j][0]])
                    q_end = q_end + turning[i][j][1]
                    q_duration = q_end - turning[i][j][0]
                    pred[i]["q_duration"].append(q_duration)
                    # find r duration
                    r_start = q_end
                    r_duration = turning[i][j][3] - r_start
                    pred[i]["r_duration"].append(r_duration)
            elif mode <= 3:
                # only q or r
                if ekg_sig[i, 0, turning[i][j][1]] > ekg_sig[i, 0,
                                                             turning[i][j][0]]:
                    # r
                    pred[i]["q_duration"].append(0)
                    pred[i]["s_duration"].append(0)
                    r_duration = turning[i][j][2] - turning[i][j][0]
                    pred[i]["r_duration"].append(r_duration)
                else:
                    # q
                    pred[i]["r_duration"].append(0)
                    pred[i]["s_duration"].append(0)
                    q_duration = turning[i][j][2] - turning[i][j][0]
                    pred[i]["q_duration"].append(q_duration)

    # reference QRS durations (presumably in ms) for the 17 calibration signals
    standard_qrs = []
    # ANE
    standard_qrs.append({"q_duration": 12, "r_duration": 52, "s_duration": 30})
    standard_qrs.append({"q_duration": 12, "r_duration": 52, "s_duration": 30})
    standard_qrs.append({"q_duration": 12, "r_duration": 52, "s_duration": 30})
    # CAL
    standard_qrs.append({"q_duration": 0, "r_duration": 50, "s_duration": 50})
    standard_qrs.append({"q_duration": 0, "r_duration": 50, "s_duration": 50})
    standard_qrs.append({"q_duration": 0, "r_duration": 50, "s_duration": 50})
    standard_qrs.append({"q_duration": 0, "r_duration": 50, "s_duration": 50})
    standard_qrs.append({"q_duration": 0, "r_duration": 50, "s_duration": 50})
    standard_qrs.append({"q_duration": 0, "r_duration": 56, "s_duration": 0})
    standard_qrs.append({"q_duration": 0, "r_duration": 56, "s_duration": 0})
    standard_qrs.append({"q_duration": 0, "r_duration": 56, "s_duration": 0})
    standard_qrs.append({"q_duration": 56, "r_duration": 0, "s_duration": 0})
    standard_qrs.append({"q_duration": 56, "r_duration": 0, "s_duration": 0})
    standard_qrs.append({"q_duration": 56, "r_duration": 0, "s_duration": 0})
    standard_qrs.append({"q_duration": 0, "r_duration": 18, "s_duration": 18})
    standard_qrs.append({"q_duration": 0, "r_duration": 50, "s_duration": 50})
    standard_qrs.append({"q_duration": 0, "r_duration": 50, "s_duration": 50})

    # 17 reference signals with 5 recordings each; predicted durations are in
    # samples, so multiply by 2 to get ms (500 Hz -> 2 ms per sample)
    mean_diff = np.zeros((3, 17))
    for i in range(17):
        q_temp_mean = []
        r_temp_mean = []
        s_temp_mean = []
        for j in range(5):
            q_temp_mean.append(np.mean(pred[i * 5 + j]["q_duration"]))
            r_temp_mean.append(np.mean(pred[i * 5 + j]["r_duration"]))
            s_temp_mean.append(np.mean(pred[i * 5 + j]["s_duration"]))
        mean_diff[0][i] = np.mean(q_temp_mean) * 2 - standard_qrs[i]["q_duration"]
        mean_diff[1][i] = np.mean(r_temp_mean) * 2 - standard_qrs[i]["r_duration"]
        mean_diff[2][i] = np.mean(s_temp_mean) * 2 - standard_qrs[i]["s_duration"]
    #print(pd.DataFrame(mean_diff.T, columns=["q","r","s"]))
    #print(np.mean(mean_diff, axis=1))
    #print(np.std(mean_diff, axis=1, ddof=1))
    mean_diff = removeworst(mean_diff, 4)
    mean_diff_mean = np.mean(mean_diff, axis=1)
    mean_diff_std = np.std(mean_diff, axis=1, ddof=1)
    print(mean_diff_mean)
    print(mean_diff_std)
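A usage sketch for the function above; the RetinaNet constructor and checkpoint path are assumptions:

net = RetinaNet(num_classes=3).cuda()            # hypothetical constructor
net.load_state_dict(torch.load("retinanet.pt"))  # hypothetical checkpoint
test_retinanet_by_qrs(net)   # prints mean and std of the duration errors (ms)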
Example #4
def test_retinanet(net, x, input_length, ground_truth=None, visual=False):
    """
    test the RetinaNet by any preprocessed signals.

    Args:
        net:            (nn.Module) RetinaNet model
        x:              (Tensor) with sized [#signals, 1 lead, values]
        input_length:   (int) input length must dividable by 64
        ground_truth:   (Tensor) with sized [batch_size, #anchors, 2]

    Returns:
        plot:       (pyplot) pyplot object
        interval:   (list of dict) with sized [#signals], for more info about dict structure, you can see utils.val_utils.validation_duration_accuracy.
    """
    net.eval()
    loc_preds, cls_preds = net(x)

    loc_preds = loc_preds.data.type(torch.FloatTensor)
    cls_preds = cls_preds.data.type(torch.FloatTensor)

    if ground_truth is not None:
        loc_targets, cls_targets = ground_truth
        loc_targets = loc_targets.data.type(torch.FloatTensor)
        cls_targets = cls_targets.data.type(torch.LongTensor)

    batch_size = x.size(0)
    encoder = DataEncoder()

    pred_sigs = []
    gt_sigs = []
    for i in range(batch_size):
        boxes, labels, sco, is_found = encoder.decode(loc_preds[i],
                                                      cls_preds[i],
                                                      input_length,
                                                      CLS_THRESH=0.425,
                                                      NMS_THRESH=0.5)
        if is_found:
            boxes = boxes.ceil()
            xmin = boxes[:, 0].clamp(min=1)
            xmax = boxes[:, 1].clamp(max=input_length - 1)

            pred_sig = box_to_sig_generator(xmin,
                                            xmax,
                                            labels,
                                            input_length,
                                            background=False)

        else:
            pred_sig = torch.zeros(1, 4, input_length)
        if ground_truth is not None:
            gt_boxes, gt_labels, gt_sco, gt_is_found = encoder.decode(
                loc_targets[i], one_hot_embedding(cls_targets[i], 4),
                input_length)
            gt_sig = box_to_sig_generator(gt_boxes[:, 0],
                                          gt_boxes[:, 1],
                                          gt_labels,
                                          input_length,
                                          background=False)
            gt_sigs.append(gt_sig)
        pred_sigs.append(pred_sig)
    pred_signals = torch.cat(pred_sigs, 0)
    pred_onset_offset = onset_offset_generator(pred_signals)
    plot = None
    if visual:
        if ground_truth is not None:
            for i in range(batch_size):
                plot = predict_plotter(x[i][0],
                                       pred_signals[i],
                                       ground_truth[i],
                                       name=str(i))
        else:
            for i in range(batch_size):
                plot = predict_plotter(x[i][0], pred_signals[i], name=str(i))

    if ground_truth is not None:
        gt_signals = torch.cat(gt_sigs, 0)
        gt_onset_offset = onset_offset_generator(gt_signals)
        TP, FP, FN = validation_accuracy(pred_onset_offset, gt_onset_offset)

    intervals = validation_duration_accuracy(pred_onset_offset[:, 1:, :])
    return plot, intervals, pred_signals
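A usage sketch for test_retinanet(), reusing the dataset helpers from Example #3; net is the RetinaNet from the surrounding examples, and the slice of 8 signals is an arbitrary choice for illustration:

ekg_sig = normalize(load_ANE_CAL(denoise=False, pre=False, nor=False))
plot, intervals, pred_signals = test_retinanet(net, ekg_sig[:8], 4992, visual=False)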
def eval_retinanet(model, dataloader):
    """
    the evaluation function that can be used during RetinaNet training.

    Args:
        model:      (nn.Module) RetinaNet module variable
        dataloader: (DataLoader) validation dataloader
        
    Returns:
        Se:     (float) TP / (TP + FN)
        PPV:    (float) TP / (TP + FP)
        F1:     (float) 2 * Se * PPV / (Se + PPV)
    """
    input_length = 3968
    model.eval()

    pred_sigs = []
    gt_sigs = []
    sigs = []
    for batch_idx, (inputs, loc_targets, cls_targets, gt_boxes, gt_labels,
                    gt_peaks) in enumerate(dataloader):
        batch_size = inputs.size(0)
        inputs = inputs.cuda()
        loc_targets = loc_targets.cuda()
        cls_targets = cls_targets.cuda()
        inputs = inputs.unsqueeze(1)
        sigs.append(inputs)

        loc_preds, cls_preds = model(inputs)

        loc_preds = loc_preds.data.squeeze().type(
            torch.FloatTensor)  # sized [#anchors * 3, 2]
        cls_preds = cls_preds.data.squeeze().type(
            torch.FloatTensor)  # sized [#anchors * 3, 3]

        loc_targets = loc_targets.data.squeeze().type(torch.FloatTensor)
        cls_targets = cls_targets.data.squeeze().type(torch.LongTensor)

        # the decoder only processes one sample at a time.
        encoder = DataEncoder()
        for i in range(batch_size):
            boxes, labels, sco, is_found = encoder.decode(
                loc_preds[i], cls_preds[i], input_length)

            # decode the ground truth using a different method
            gt_boxes_tensor = torch.tensor(gt_boxes[i])
            gt_labels_tensor = torch.tensor(gt_labels[i])
            xmin = gt_boxes_tensor[:, 0].clamp(min=1)
            xmax = gt_boxes_tensor[:, 1].clamp(max=input_length - 1)
            gt_sig = box_to_sig_generator(xmin,
                                          xmax,
                                          gt_labels_tensor,
                                          input_length,
                                          background=False)

            if is_found:
                boxes = boxes.ceil()
                xmin = boxes[:, 0].clamp(min=1)
                xmax = boxes[:, 1].clamp(max=input_length - 1)

                # there is no background anchor on predict labels
                pred_sig = box_to_sig_generator(xmin,
                                                xmax,
                                                labels,
                                                input_length,
                                                background=False)
            else:
                pred_sig = torch.zeros(1, 4, input_length)

            pred_sigs.append(pred_sig)
            gt_sigs.append(gt_sig)
    sigs = torch.cat(sigs, 0)
    pred_signals = torch.cat(pred_sigs, 0)
    gt_signals = torch.cat(gt_sigs, 0)
    plot = predict_plotter(sigs[0][0], pred_signals[0], gt_signals[0])
    #wandb.log({"visualization": plot})
    pred_onset_offset = onset_offset_generator(pred_signals)
    gt_onset_offset = onset_offset_generator(gt_signals)
    TP, FP, FN = validation_accuracy(pred_onset_offset, gt_onset_offset)

    Se = TP / (TP + FN)
    PPV = TP / (TP + FP)
    F1 = 2 * Se * PPV / (Se + PPV)

    print("Se: {} PPV: {} F1 score: {}".format(Se, PPV, F1))
    wandb.log({"Se": Se, "PPV": PPV, "F1": F1})

    return Se, PPV, F1
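A usage sketch for eval_retinanet(); val_dataset is a hypothetical dataset, and the loader must yield (inputs, loc_targets, cls_targets, gt_boxes, gt_labels, gt_peaks) tuples as the loop above expects:

val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)  # val_dataset is hypothetical
Se, PPV, F1 = eval_retinanet(model, val_loader)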
def eval_unet(net, loader, device):
    """
    the evaluation function that can be used during UNet training.

    Args:
        net:    (nn.Module) UNet module variable
        loader: (DataLoader) validation dataloader
        device: (str) using GPU or CPU
    Returns:
        average loss:   (float) average loss within validation set
        pointwise acc:  (float) pointwise evaluation
        Se:             (float) TP / (TP + FN)
        PPV:            (float) TP / (TP + FP)
        F1:             (float) 2 * Se * PPV / (Se + PPV)
        ret:            (list of dict) see validation_duration_accuracy for more detail
    """
    net.eval()
    n_val = len(loader)
    tot = 0
    correct = 0
    total = 0

    TP = 0
    FP = 0
    FN = 0

    with tqdm(total=n_val,
              desc='Validation round',
              unit='batch',
              leave=False,
              ncols=100) as pbar:
        for batch in loader:
            x, ground_truth = batch[0], batch[1]
            x = x.to(device, dtype=torch.float32)
            ground_truth = ground_truth.to(device, dtype=torch.float32)

            with torch.no_grad():
                pred = net(x)

            tot += F.binary_cross_entropy_with_logits(pred,
                                                      ground_truth).item()
            # (batch_size, channels, data)
            pred_ans = F.one_hot(pred.argmax(1),
                                 num_classes=4).permute(0, 2, 1)
            correct += pred_ans.eq(ground_truth).sum().item()
            total += ground_truth.shape[0] * ground_truth.shape[
                1] * ground_truth.shape[2]

            # only the first three channels are needed; they produce all
            # onsets/offsets
            pred_onset_offset = onset_offset_generator(pred_ans[:, :3, :])
            gt_onset_offset = onset_offset_generator(ground_truth[:, :3, :])
            tp, fp, fn = validation_accuracy(pred_onset_offset,
                                             gt_onset_offset)
            # ret is overwritten every batch, so only the last batch's
            # intervals are returned
            ret = validation_duration_accuracy(pred_onset_offset)
            TP += tp
            FP += fp
            FN += fn

            pbar.update()

    Se = TP / (TP + FN)
    PPV = TP / (TP + FP)
    F1 = 2 * Se * PPV / (Se + PPV)

    return tot / n_val, correct / total, Se, PPV, F1, ret
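A usage sketch for eval_unet(); val_dataset is a hypothetical dataset whose batches yield (signal, one-hot ground truth) pairs:

device = "cuda" if torch.cuda.is_available() else "cpu"
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)  # val_dataset is hypothetical
loss, acc, Se, PPV, F1, ret = eval_unet(net.to(device), val_loader, device)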