import numpy as np

import mala


def get_um_loss_gradient(mst, dist, gt_seg, alpha):
    '''Compute the ultrametric loss gradient given an MST and segmentation.

    Args:

        mst (Tensor, shape ``(3, n-1)``):

            u, v indices and distance of edges of the MST spanning n nodes.

        dist (Tensor, shape ``(n-1)``):

            The distances of the edges. This argument is ignored; it is
            used only to communicate to TensorFlow that there is a
            dependency on the distances. The distances actually used are
            the ones in parameter ``mst``.

        gt_seg (Tensor, arbitrary shape):

            The label of each node. Will be flattened. The indices in
            ``mst`` should be valid indices into this array.

        alpha (Tensor, single float):

            The margin value of the quadruplet loss.

    Returns:

        A Tensor containing the gradient on the distances.
    '''

    # We don't use 'dist' here; it is already contained in the mst. It is
    # passed only so that TensorFlow knows the output depends on it.
    (_, gradient, _, _, _, _) = mala.um_loss(
        mst,
        gt_seg.flatten(),
        alpha)

    return gradient.astype(np.float32)
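For intuition about what this loss operates on, here is a minimal sketch of the ultrametric that an MST induces, assuming the ``(3, n-1)`` layout from the docstring above (rows u, v, distance). The ultrametric distance between two nodes is the largest edge distance on their MST path, which equals the distance of the edge at which they first merge when edges are added in ascending order. The helper ``ultrametric_from_mst`` is hypothetical and not part of ``mala``; it uses NumPy only and does not reproduce the loss or its gradient.

```python
import numpy as np


def ultrametric_from_mst(mst, num_nodes):
    """Hypothetical helper (not part of mala): n x n ultrametric from an MST.

    Assumes ``mst`` has shape (3, n-1) with rows (u, v, distance). Two
    nodes get the distance of the edge that first connects them when
    edges are processed in ascending order of distance.
    """
    # Union-find over nodes; members[root] lists the nodes of that component.
    parent = list(range(num_nodes))
    members = {i: [i] for i in range(num_nodes)}

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    um = np.zeros((num_nodes, num_nodes), dtype=np.float64)
    # Kruskal order: ascending edge distance.
    for idx in np.argsort(mst[2], kind="stable"):
        ru = find(int(mst[0, idx]))
        rv = find(int(mst[1, idx]))
        if ru == rv:
            continue  # does not happen for a valid spanning tree
        d = mst[2, idx]
        # Every pair split across the two merging components gets distance d.
        a = np.array(members[ru])
        b = np.array(members[rv])
        um[np.ix_(a, b)] = d
        um[np.ix_(b, a)] = d
        parent[rv] = ru
        members[ru].extend(members.pop(rv))
    return um
```

For the chain 0-1-2-3 with edge distances 0.5, 2.0 and 1.0, nodes 0 and 3 are only connected once the 2.0 edge is added, so their ultrametric distance is 2.0.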
def get_um_loss(mst, dist, gt_seg, alpha):
    '''Compute the ultrametric loss given an MST and segmentation.

    Args:

        mst (Tensor, shape ``(3, n-1)``):

            u, v indices and distance of edges of the MST spanning n nodes.

        dist (Tensor, shape ``(n-1)``):

            The distances of the edges. This argument is ignored; it is
            used only to communicate to TensorFlow that there is a
            dependency on the distances. The distances actually used are
            the ones in parameter ``mst``.

        gt_seg (Tensor, arbitrary shape):

            The label of each node. Will be flattened. The indices in
            ``mst`` should be valid indices into this array.

        alpha (Tensor, single float):

            The margin value of the quadruplet loss.

    Returns:

        A tuple::

            (loss, ratio_pos, ratio_neg, num_pairs_pos, num_pairs_neg)

        ``loss``, ``num_pairs_pos``, and ``num_pairs_neg`` are scalars.
        ``ratio_pos`` and ``ratio_neg`` are tensors of shape ``(n-1,)``,
        corresponding to the edges in the MST: for each edge, they give the
        number of positive and negative pairs that share that edge, as a
        fraction of the total number of positive and negative pairs,
        respectively. ``num_pairs_pos`` and ``num_pairs_neg`` are those
        total numbers of positive and negative pairs.
    '''

    # We don't use 'dist' here; it is already contained in the mst. It is
    # passed only so that TensorFlow knows the output depends on it.
    (loss, _, ratio_pos, ratio_neg, num_pairs_pos, num_pairs_neg) = mala.um_loss(
        mst,
        gt_seg.flatten(),
        alpha)

    return (
        np.float32(loss),
        ratio_pos.astype(np.float32),
        ratio_neg.astype(np.float32),
        np.float32(num_pairs_pos),
        np.float32(num_pairs_neg))
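The per-edge ratios and pair counts that ``get_um_loss`` returns can be illustrated with a small union-find sketch. This is a hypothetical helper (``pair_ratios_per_edge``), not ``mala``'s implementation, and it rests on two assumptions: a pair is positive when both nodes carry the same label in ``gt_seg`` and negative otherwise (no special handling of a background label), and each pair is attributed to the MST edge that first connects its two nodes when edges are processed in ascending order of distance.

```python
from collections import Counter

import numpy as np


def pair_ratios_per_edge(mst, gt_seg):
    """Hypothetical helper (not part of mala).

    Returns (ratio_pos, ratio_neg, num_pairs_pos, num_pairs_neg), assuming
    same-label pairs are positive and each pair belongs to the edge that
    first joins its two nodes in ascending order of distance.
    """
    labels = np.asarray(gt_seg).ravel()
    n = labels.size
    parent = list(range(n))
    # Label histogram of each component, keyed by the component's root.
    hist = {i: Counter({labels[i]: 1}) for i in range(n)}

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    num_edges = mst.shape[1]
    pos = np.zeros(num_edges, dtype=np.float64)
    neg = np.zeros(num_edges, dtype=np.float64)
    for idx in np.argsort(mst[2], kind="stable"):
        ru = find(int(mst[0, idx]))
        rv = find(int(mst[1, idx]))
        if ru == rv:
            continue  # does not happen for a valid spanning tree
        hu, hv = hist[ru], hist[rv]
        size_u = sum(hu.values())
        size_v = sum(hv.values())
        # Cross-component pairs sharing a label are positive; the rest negative.
        same = sum(count * hv[label] for label, count in hu.items())
        pos[idx] = same
        neg[idx] = size_u * size_v - same
        # Merge rv into ru, combining the label histograms.
        parent[rv] = ru
        hu.update(hv)
        del hist[rv]

    num_pairs_pos = pos.sum()
    num_pairs_neg = neg.sum()
    ratio_pos = pos / num_pairs_pos if num_pairs_pos > 0 else pos
    ratio_neg = neg / num_pairs_neg if num_pairs_neg > 0 else neg
    return ratio_pos, ratio_neg, num_pairs_pos, num_pairs_neg
```

On the chain 0-1-2-3 with edge distances 0.5, 2.0 and 1.0 and labels ``[1, 1, 2, 2]``, this gives two positive and four negative pairs: each short edge carries one positive pair, and the 2.0 edge carries all four negative pairs. Under these assumptions the counts always sum to n(n-1)/2, since every pair is attributed to exactly one edge.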