def v_dice_coeff(P, L, use_argmax=False, one_hot_labels=False):
    """
    Calculates V-Dice for given predictions and labels.
    WARNING! THIS IMPLIES SEGMENTATION CONTEXT.

    Parameters
    ----------
    P : np.ndarray
        Predictions of a segmentator. Array of shape [batch_sz, W, H, num_classes].
    L : np.ndarray
        Labels for the segmentator. Array of shape [batch_sz, W, H],
        or [batch_sz, W, H, num_classes] if `one_hot_labels` is True.
    use_argmax : bool
        Converts the segmentator's predictions to one-hot format.
        Example: [0.4, 0.1, 0.5] -> [0., 0., 1.]
    one_hot_labels : bool
        Set to True if the labels (`L`) are already one-hot encoded,
        i.e. have the same shape as `P`.

    Returns
    -------
    tuple(float, np.ndarray)
        Mean dice over all classes and the per-class dice values
        (array of length num_classes).
    """
    batch_sz = len(P)
    num_samples = P.shape[1] * P.shape[2]
    num_classes = P.shape[-1]

    if use_argmax:
        # Harden soft predictions into one-hot vectors.
        P = P.argmax(axis=3)
        P = P.reshape(batch_sz * num_samples)
        P = one_hot(P, depth=num_classes)

    # BUGFIX: the original only reshaped P inside the `use_argmax` branch,
    # leaving P as [batch_sz, W, H, num_classes] otherwise, which broke the
    # elementwise product with the flattened labels below. Always flatten
    # the spatial dimensions: P -> [batch_sz, num_samples, num_classes].
    P = P.reshape(batch_sz, num_samples, num_classes)

    if not one_hot_labels:
        # L -> [batch_sz*num_samples] -> one-hot [batch_sz*num_samples, num_classes]
        L = L.reshape(batch_sz * num_samples)
        L = one_hot(L, depth=num_classes)
    # BUGFIX: already-one-hot labels arrive as [batch_sz, W, H, num_classes];
    # flatten them the same way as the predictions.
    L = L.reshape(batch_sz, num_samples, num_classes)

    # P and L now both have shape [batch_sz, num_samples, num_classes].
    # Per-sample, per-class intersection and denominator terms.
    nums = (P * L).sum(axis=1)
    dens = (P * P).sum(axis=1) + L.sum(axis=1)
    # EPSILON smooths the ratio so classes absent from both prediction
    # and ground truth yield dice == 1 instead of 0/0.
    dices_b = (2 * nums + EPSILON) / (dens + EPSILON)
    dices = dices_b.mean(axis=0)
    return dices.mean(), dices
def categorical_dice_coeff(P, L, use_argmax=False, ind_norm=True):
    """
    Calculates per-class (categorical) Dice for given predictions and labels.
    WARNING! THIS IMPLIES SEGMENTATION CONTEXT.

    Parameters
    ----------
    P : np.ndarray
        Predictions of a segmentator. Array of shape [batch_sz, W, H, num_classes].
    L : np.ndarray
        Labels for the segmentator. Array of shape [batch_sz, W, H]
    use_argmax : bool
        Converts the segmentator's predictions to one-hot format.
        Example: [0.4, 0.1, 0.5] -> [0., 0., 1.]
    ind_norm : bool
        Normalize each dice separately. Useful in case some classes
        don't appear on some images.

    Returns
    -------
    tuple(float, np.ndarray)
        Mean dice over all classes and the per-class dice values
        (array of length num_classes).
    """
    L = np.asarray(L)
    P = np.asarray(P)
    batch_sz = len(P)
    num_classes = P.shape[-1]

    if use_argmax:
        # Harden soft predictions into one-hot vectors.
        P = P.argmax(axis=3)
        P = P.reshape(-1)
        P = one_hot(P, depth=num_classes)
    # BUGFIX: the original only reshaped P inside the `use_argmax` branch,
    # so with use_argmax=False `sample_pred[:, j]` indexed the H axis of a
    # [W, H, C] array instead of the class axis. Always flatten the spatial
    # dimensions: P -> [batch_sz, num_samples, num_classes].
    P = P.reshape(batch_sz, -1, num_classes)

    L = L.reshape(batch_sz, -1)

    class_dices = np.zeros(num_classes)
    # EPSILON smoothing avoids division by zero for classes that never
    # contribute a dice term.
    class_counts = np.zeros(num_classes) + EPSILON
    for i in range(batch_sz):
        sample_actual = L[i]
        sample_pred = P[i]
        for j in range(num_classes):
            sub_actual = (sample_actual == j).astype(np.int32)
            sub_confs = sample_pred[:, j]
            # Class absent from both prediction and ground truth on this
            # image: skip it so it does not drag down the class average.
            if np.sum(sub_actual) == 0 and np.sum(sub_confs) == 0:
                continue
            class_dices[j] += binary_dice(sub_confs, sub_actual)
            class_counts[j] += 1

    # Compute only the requested normalization (the original computed the
    # batch-normalized result unconditionally and then discarded it when
    # `ind_norm` was set).
    if ind_norm:
        dices = class_dices / class_counts
    else:
        dices = class_dices / batch_sz
    return dices.mean(), dices