def normalize_staining_torch(img: torch.Tensor, io=240, alpha=1, beta=0.15):
    """
    Normalize the staining appearance of an H&E stained image (Macenko method).

    PyTorch port of
    https://github.com/schaugf/HEnorm_python/blob/master/normalizeStaining.py

    Args:
        img: RGB image tensor of shape (H, W, 3). Presumably uint8 values in
            [0, 255] -- TODO confirm with callers.
        io: Transmitted light intensity used for the optical-density (OD)
            conversion.
        alpha: Percentile (together with ``100 - alpha``) used to pick robust
            extreme stain-vector angles.
        beta: OD threshold. Pixels with any channel below it are treated as
            transparent and excluded from stain estimation.

    Returns:
        A uint8 tensor of shape (H, W, 3) on ``img``'s device holding the
        normalized image. If no pixel survives the ``beta`` threshold, or the
        OD covariance contains NaN/Inf, ``img`` is returned unchanged.

    Notes:
        Relies on the module-level helpers ``covar`` and ``percentile``. They
        are assumed to return, respectively, the covariance matrix of the rows
        of their argument and a percentile as a tensor -- verify against
        their definitions.

    Reference:
        A method for normalizing histology slides for quantitative analysis.
        M. Macenko et al., ISBI 2009
    """
    # Follow the input's device. A hard-coded 'cuda:0' made CPU inputs crash
    # on mixed-device matmuls whenever CUDA was available.
    device = img.device

    # Reference stain matrix (rows: RGB OD, columns: hematoxylin, eosin) and
    # reference maximal stain concentrations, taken from the upstream code.
    HERef = torch.tensor([[0.5626, 0.2159],
                          [0.7201, 0.8012],
                          [0.4062, 0.5581]], dtype=torch.float64, device=device)
    maxCRef = torch.tensor([1.9705, 1.0308], dtype=torch.float64, device=device)

    h, w, _ = img.shape
    pixels = img.reshape((-1, 3))

    # Optical density; the +1 avoids log(0) for fully dark pixels.
    OD = -torch.log((pixels.double() + 1) / io)

    # Drop pixels where any channel's OD is below beta.
    ODhat = OD[~torch.any(OD < beta, dim=1)]
    if len(ODhat) == 0:
        return img

    cov = covar(ODhat.T)
    if not torch.isfinite(cov).all():
        return img

    # torch.symeig is removed in recent PyTorch. torch.linalg.eigh also
    # returns eigenvalues in ascending order, so columns 1:3 belong to the
    # two largest eigenvalues.
    _, eigvecs = torch.linalg.eigh(cov)
    plane = eigvecs[:, 1:3]

    # Project onto the plane spanned by those eigenvectors and take each
    # pixel's angle within it.
    That = ODhat.mm(plane)
    phi = torch.atan2(That[:, 1], That[:, 0])

    minPhi = percentile(phi, alpha)
    maxPhi = percentile(phi, 100 - alpha)

    # Map the robust extreme angles back to 3-D OD directions (3x1 each),
    # matching dtype and device of the eigenvectors.
    vMin = plane.mm(torch.stack((torch.cos(minPhi), torch.sin(minPhi))).reshape(2, 1).to(plane))
    vMax = plane.mm(torch.stack((torch.cos(maxPhi), torch.sin(maxPhi))).reshape(2, 1).to(plane))

    # Heuristic from the upstream code: put hematoxylin first and eosin
    # second (columns of HE).
    if vMin[0, 0] > vMax[0, 0]:
        HE = torch.cat((vMin, vMax), dim=1)
    else:
        HE = torch.cat((vMax, vMin), dim=1)

    # Rows correspond to RGB channels, columns to pixels.
    Y = OD.T

    # Stain concentrations per pixel, shape (2, N).
    C = torch.pinverse(HE).mm(Y)

    # Rescale so each stain's 99th-percentile concentration matches the reference.
    maxC = torch.stack((percentile(C[0, :], 99), percentile(C[1, :], 99))).reshape(-1).to(C)
    C2 = C / (maxC / maxCRef).unsqueeze(1)

    # Rebuild the image with the reference stain matrix.
    Inorm = io * torch.exp(-HERef.mm(C2))
    # Upstream clamps overflowing values to 254 (not 255); kept for parity.
    Inorm[Inorm > 255] = 254
    return torch.reshape(Inorm.T, (h, w, 3)).to(torch.uint8)
Example #2
0
def computeNonRigidTransformation(
    source    :torch.Tensor,
    target    :torch.Tensor
):
    """
    Estimate a similarity transform (rotation, translation, isotropic scale)
    that aligns ``source`` points to ``target`` points.

    Uses the SVD-based Kabsch/Umeyama construction with a reflection
    correction, so every returned ``R`` has determinant +1. Correspondences
    are positional: column ``n`` of ``source`` is paired with column ``n`` of
    ``target``.

    Args:
        source: Points of shape (c, N) or (b, c, N).
        target: Points with the same shape as ``source``.

    Returns:
        R: (b, c, c) float64 rotation matrices.
        t: (b, c, 1) float64 translations, computed as
            ``centroid_target - R @ centroid_source``.
        scale: (b, 1, 1) float64 scale factors. Inf/NaN if all points of a
            source batch coincide (zero variance).

    Raises:
        AssertionError: if batch size, channel count, or point count differ.

    NOTE(review): ``t`` does not include ``scale`` (Umeyama's translation is
    ``centroid_target - scale * R @ centroid_source``). Whether that is
    intended depends on how callers compose R, t and scale -- confirm.
    """
    if source.dim() == 2:
        source = source.unsqueeze(0)
    if target.dim() == 2:
        target = target.unsqueeze(0)

    b1, c1, N1 = source.shape
    b2, c2, N2 = target.shape

    assert b1==b2, "Batch sizes differ" #TODO: maybe change later so that it could support b1=K, b2=1
    assert c1==c2, "Inputs channels differ"
    assert N1==N2, "Number of samples differ"

    b, c = b1, c1

    # .double() is a no-op for tensors that are already float64.
    source = source.double()
    target = target.double()
    device = source.device

    centroid_source = source.mean(dim=2, keepdim=True)    # b x c x 1
    centroid_target = target.mean(dim=2, keepdim=True)    # b x c x 1

    centered_source = source - centroid_source
    centered_target = target - centroid_target

    # Per-batch sum of squared deviations of the source points, shape (b,).
    variance_source = centered_source.pow(2).sum(dim=(1, 2))

    # Cross-covariance sum_n x_n y_n^T, shape (b, c, c). bmm avoids the
    # (b, c, c, N) intermediate of an einsum-then-sum formulation.
    H = torch.bmm(centered_source, centered_target.transpose(1, 2))

    # Batched SVD: H = U diag(S) V^T.
    U, _, Vh = torch.linalg.svd(H)
    V = Vh.transpose(-2, -1)
    Ut = U.transpose(-2, -1)

    # If V U^T is a reflection, flip the last axis so R is a proper rotation.
    correction = torch.ones(b, c, dtype=H.dtype, device=device)
    correction[:, -1] = torch.sign(torch.det(torch.bmm(V, Ut)))
    R = torch.bmm(V, torch.bmm(torch.diag_embed(correction), Ut))

    # Optimal isotropic scale: trace(R H) / source variance.
    scale = torch.diagonal(torch.bmm(R, H), dim1=-2, dim2=-1).sum(dim=-1) / variance_source
    scale = scale.reshape(b, 1, 1)

    t = -torch.bmm(R, centroid_source) + centroid_target
    return R, t, scale