예제 #1
0
def extract_super_pixels(mat_adj=None,
                         test_left=None,
                         test_right=None,
                         mat_cat=None,
                         num_neighbors=8,
                         cor_choice='mean',
                         connectivity=None,
                         min_pixels=50,
                         image=None,
                         plot=False,
                         use_mean_image=False):
    """Label connected regions in an image, building the image first if needed.

    When ``image`` is supplied it is segmented directly. Otherwise a map is
    computed with ``neighbor_cor`` from ``mat_cat`` (or, if ``mat_cat`` is
    also missing, from the ``[test_left:test_right]`` slice of every entry of
    ``mat_adj`` concatenated along dim 0).

    Args:
        mat_adj: iterable of tensors; only used when both ``image`` and
            ``mat_cat`` are None.
        test_left, test_right: slice bounds applied to each entry of
            ``mat_adj`` along its first dimension.
        mat_cat: tensor passed to ``neighbor_cor``; built from ``mat_adj``
            when None.
        num_neighbors: forwarded to ``neighbor_cor`` as ``neighbors``.
        cor_choice: forwarded to ``neighbor_cor`` as ``choice``.
        connectivity: forwarded to ``get_label_image``.
        min_pixels: forwarded to ``get_label_image``.
        image: optional precomputed image (torch.Tensor or array).
        plot: if True, ``neighbor_cor`` is asked to plot and the label
            overlay is drawn at the end.
        use_mean_image: if True, the map is multiplied by ``mat_cat.mean(0)``
            and rescaled by its maximum.

    Returns:
        (cor_global, label_image, regions). ``cor_global`` is the computed
        map, or the caller's ``image`` exactly as passed in.
    """
    if image is not None:
        # Hand the caller's object back untouched (backward compatibility);
        # only the local working copy is converted to NumPy.
        cor_global = image
        if isinstance(image, torch.Tensor):
            image = image.detach().cpu().numpy()
    else:
        if mat_cat is None:
            window = slice(test_left, test_right)
            mat_cat = torch.cat([m[window] for m in mat_adj], dim=0)
        cor_global = neighbor_cor(mat_cat,
                                  neighbors=num_neighbors,
                                  plot=plot,
                                  choice=cor_choice,
                                  title='correlation map')
        if use_mean_image:
            weighted = mat_cat.mean(0) * cor_global
            cor_global = weighted / weighted.max()
        image = cor_global.detach().cpu().numpy()

    # Plotting is deferred to the overlay below, so it is disabled here.
    label_image, regions = get_label_image(image,
                                           min_pixels=min_pixels,
                                           connectivity=connectivity,
                                           plot=False)
    if plot:
        plot_image_label_overlay(image, label_image)
    return cor_global, label_image, regions
예제 #2
0
def refine_one_label(submat, min_pixels=50, return_traces=False, percentile=50):
    """Re-segment one patch using its attention map.

    Args:
        submat: input passed to ``attention_map`` and, when traces are
            requested, to ``extract_traces``.
        min_pixels: forwarded to ``get_label_image``.
        return_traces: selects the return form (see Returns).
        percentile: forwarded to ``extract_traces``.

    Returns:
        ``label_image`` alone when ``return_traces`` is False; otherwise the
        tuple ``(submats, traces, soft_attention, label_image, regions)``.
    """
    attention = attention_map(submat)
    labels, props = get_label_image(attention, min_pixels=min_pixels)
    if not return_traces:
        return labels
    submats, traces = extract_traces(submat,
                                     softmask=attention,
                                     label_image=labels,
                                     regions=props,
                                     percentile=percentile)
    return submats, traces, attention, labels, props
예제 #3
0
def get_high_conf_mask(cor_map,
                       low_percentile=25,
                       high_percentile=5,
                       min_cor=0.1,
                       min_pixels=20,
                       exclude_boundary_width=2):
    """Return a high-confidence foreground/background mask.

    Args:
        cor_map: torch.Tensor, indexed below as a 2-D map.
        low_percentile: a pixel is a background candidate when its value is
            at or below this percentile of ``cor_map``.
        high_percentile: a pixel is a foreground candidate when its value is
            at or above this percentile of ``cor_map``.
            NOTE(review): the default 5 selects the 5th percentile, not the
            top 5% -- confirm this is intended.
        min_cor: foreground must be strictly above, background strictly
            below, this value.
        min_pixels: forwarded to ``get_label_image``.
        exclude_boundary_width: number of rows/columns along each edge that
            are forced to -1. Values <= 0 disable the boundary exclusion.

    Returns:
        Tensor created with ``cor_map.new_full`` (same size as ``cor_map``):
        1 for foreground, 0 for background, -1 everywhere else.
    """
    label_image, regions = get_label_image(cor_map, min_pixels=min_pixels)

    # Convert once and reuse for both percentile thresholds.
    values = cor_map.detach().cpu().numpy()
    high_thresh = float(np.percentile(values, high_percentile))
    low_thresh = float(np.percentile(values, low_percentile))
    labels = cor_map.new_tensor(label_image)

    fg_mask = (cor_map >= high_thresh) & (labels > 0) & (cor_map > min_cor)
    bg_mask = (cor_map <= low_thresh) & (labels == 0) & (cor_map < min_cor)

    mask = cor_map.new_full(cor_map.size(), -1)
    mask[bg_mask] = 0
    mask[fg_mask] = 1

    width = exclude_boundary_width
    # Guard needed: with width == 0, ``mask[-0:]`` is ``mask[0:]`` and would
    # overwrite the entire mask with -1 instead of excluding nothing.
    if width > 0:
        mask[:width] = -1
        mask[-width:] = -1
        mask[:, :width] = -1
        mask[:, -width:] = -1
    return mask
예제 #4
0
def refine_segmentation(submat,
                        label_mask,
                        label_image,
                        label_idx,
                        plot=False,
                        figsize=(15, 10)):
    """Try to split one labeled region into sub-regions.

    Runs ``svd`` on the pixels selected by ``label_mask``; when it reports at
    least two components, decomposes ``submat`` with ``step_decompose``,
    clusters the first returned spatial component with KMeans, and passes the
    resulting cluster map to ``get_label_image`` as a split of label
    ``label_idx - 1``.

    Args:
        submat: tensor indexed as ``submat[:, mask]`` and reshaped to
            ``(submat.size(0), -1)``; its dims 1 and 2 give the map shape.
        label_mask: tensor converted with ``.bool()`` to select pixels.
        label_image: existing label image forwarded to ``get_label_image``.
        label_idx: label to split; forwarded as ``label_idx - 1``.
        plot: forwarded to ``svd`` and ``get_label_image``.
        figsize: forwarded to ``get_label_image``.

    Returns:
        None. NOTE(review): the updated label image and regions are computed
        but not returned in the visible code -- the function may be truncated;
        confirm against the full source.
    """
    num_pcs, _u, _s, _v = svd(submat[:, label_mask.bool()], plot=plot)
    if not (num_pcs >= 2):
        return None

    n_components = 2 * num_pcs
    spatial_shape = (submat.size(1), submat.size(2))

    flattened = submat.reshape(submat.size(0), -1)
    _a, spatial, _loss_history = step_decompose(flattened,
                                                num_components=n_components)

    # Cluster the pixel values of the first component (one feature per pixel).
    mask = spatial[0].reshape(spatial_shape)
    features = mask.reshape(-1, 1)
    clusterer = KMeans(n_clusters=n_components, random_state=0)
    clusterer = clusterer.fit(features.detach().cpu())
    split_label_segmentation = clusterer.labels_.reshape(spatial_shape)

    label_image, regions = get_label_image(
        image=None,
        label_image=label_image,
        split_label=label_idx - 1,
        split_label_segmentation=split_label_segmentation,
        plot=plot,
        figsize=figsize)
예제 #5
0
def _rfft_real_imag(x):
    """One-sided orthonormal real FFT over the last dim, as (..., freq, 2) real/imag pairs."""
    if hasattr(torch, 'rfft'):
        # Legacy API (PyTorch < 1.8).
        return torch.rfft(x, signal_ndim=1, normalized=True)
    # torch.rfft was removed in PyTorch 1.8; this is its documented equivalent.
    return torch.view_as_real(torch.fft.rfft(x, dim=-1, norm='ortho'))


def basic_segmentation(mat, min_thresh=0.05, min_pixels=50, select_frames=True, show=True, median_detrend=False,
                       fft=False, fft_max_freq=200):
    """Basic segmentation

    Args:
        mat: torch.Tensor with shape (nframe, nrow, ncol) or (n_experiments, nframe, nrow, ncol)
        min_thresh: float, used by get_label_image
        min_pixels: int, used by get_label_image
        select_frames: default True, only used when mat.ndim==4, selecting only frames with average mean beyond an otsu threshold
        show: default True, display the correlation map and the label overlay
        median_detrend: default False; if True, subtract get_local_median(mat, window_size=50, dim=-3) from mat
        fft: default False; if True, replace the frame axis by the real and imaginary parts of its
            one-sided orthonormal FFT before computing the correlation map
        fft_max_freq: int, number of lowest frequency bins kept when fft is True

    Returns:
        cor_map: torch.Tensor with shape (nrow, ncol)
        label_image: torch.Tensor with shape (nrow, ncol); 0 is background; label==i is mask for label i for i >= 1
        regions: object returned by regionprops

    Raises:
        ValueError: if mat is neither 3-D nor 4-D.
    """
    if mat.ndim not in (3, 4):
        # Previously this fell through to an unbound ``cor_map``.
        raise ValueError(
            'mat must have 3 or 4 dimensions, got {}'.format(mat.ndim))
    if mat.dtype == torch.float16:
        # Work in float32 for the remaining steps.
        mat = mat.float()
    if median_detrend:
        mat = mat - get_local_median(mat, window_size=50, dim=-3)
    if fft:
        # Move the frame axis last, transform, keep the lowest bins, then fold
        # the (freq, real/imag) pair into one axis and move it back in place.
        if mat.ndim == 3:
            spectrum = _rfft_real_imag(mat.transpose(0, 2))[..., :fft_max_freq, :]
            mat = spectrum.reshape(mat.size(2), mat.size(1), -1).transpose(0, 2)
        else:
            spectrum = _rfft_real_imag(mat.transpose(1, 3))[..., :fft_max_freq, :]
            mat = spectrum.reshape(mat.size(0), mat.size(3), mat.size(2), -1).transpose(1, 3)
    if mat.ndim == 3:
        cor_map = get_cor_map(mat)
    else:
        cor_map = get_cor_map_4d(mat, select_frames=select_frames, top_cor_map_percentage=20, padding=2, topk=5,
                                 shift_times=[0, 1, 2], return_all=False, plot=False)
    label_image, regions = get_label_image(cor_map, min_thresh=min_thresh, min_pixels=min_pixels)
    # get_label_image returns a NumPy array; hand back a tensor on mat's device.
    label_image = torch.from_numpy(label_image).to(mat.device)
    if show:
        imshow(cor_map)
        plot_image_label_overlay(cor_map, label_image=label_image, regions=regions)
    return cor_map, label_image, regions