def extract_super_pixels(mat_adj=None, test_left=None, test_right=None, mat_cat=None, num_neighbors=8,
                         cor_choice='mean', connectivity=None, min_pixels=50, image=None, plot=False,
                         use_mean_image=False):
    """Label connected "super pixel" regions of a correlation image.

    If ``image`` is not supplied, a correlation map is computed with ``neighbor_cor`` from
    ``mat_cat``. When ``mat_cat`` is also missing, it is first built by concatenating
    ``m[test_left:test_right]`` for every ``m`` in ``mat_adj`` along dim 0.

    Args:
        mat_adj: iterable of tensors; only used when both ``image`` and ``mat_cat`` are None.
        test_left, test_right: slice bounds applied to dim 0 of each tensor in ``mat_adj``.
        mat_cat: tensor passed to ``neighbor_cor``; presumably (frames, rows, cols) — TODO confirm.
        num_neighbors: forwarded to ``neighbor_cor`` as ``neighbors``.
        cor_choice: forwarded to ``neighbor_cor`` as ``choice``.
        connectivity, min_pixels: forwarded to ``get_label_image``.
        image: precomputed image to segment (tensor or array); skips the correlation step.
        plot: forwarded to ``neighbor_cor``; also shows a label overlay when True.
        use_mean_image: weight the correlation map by ``mat_cat.mean(0)`` and rescale by its max.

    Returns:
        (cor_global, label_image, regions). ``cor_global`` is the computed correlation map, or the
        ``image`` argument itself when one was given (kept for backward compatibility).
        ``label_image`` and ``regions`` come from ``get_label_image``.

    Raises:
        TypeError: if ``image``, ``mat_cat`` and ``mat_adj`` are all None.
    """
    if image is None:
        if mat_cat is None:
            if mat_adj is None:
                # Fail with a clear message instead of "'NoneType' object is not iterable".
                raise TypeError('extract_super_pixels requires one of image, mat_cat or mat_adj')
            mat_cat = torch.cat([m[test_left:test_right] for m in mat_adj], dim=0)
        cor_global = neighbor_cor(mat_cat, neighbors=num_neighbors, plot=plot, choice=cor_choice,
                                  title='correlation map')
        if use_mean_image:
            # Weight the correlation by the mean image, then rescale so the maximum becomes 1.
            cor_global = mat_cat.mean(0) * cor_global
            cor_global = cor_global / cor_global.max()
        image = cor_global.detach().cpu().numpy()
    else:
        cor_global = image  # for backward compatibility
    # get_label_image is fed a NumPy array.
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    label_image, regions = get_label_image(image, min_pixels=min_pixels, connectivity=connectivity, plot=False)
    if plot:
        plot_image_label_overlay(image, label_image)
    return cor_global, label_image, regions
def refine_one_label(submat, min_pixels=50, return_traces=False, percentile=50):
    """Re-segment one sub-movie using its soft attention map.

    The attention map produced by ``attention_map(submat)`` is labeled with ``get_label_image``.

    Returns:
        If ``return_traces`` is False, only the label image.
        Otherwise ``(submats, traces, soft_attention, label_image, regions)``, where ``submats`` and
        ``traces`` come from ``extract_traces`` using the attention map as a soft mask.
    """
    attn = attention_map(submat)
    labels, props = get_label_image(attn, min_pixels=min_pixels)

    # Cheap path: caller only wants the refined labels.
    if not return_traces:
        return labels

    sub_movies, sub_traces = extract_traces(
        submat,
        softmask=attn,
        label_image=labels,
        regions=props,
        percentile=percentile,
    )
    return sub_movies, sub_traces, attn, labels, props
def get_high_conf_mask(cor_map, low_percentile=25, high_percentile=5, min_cor=0.1, min_pixels=20,
                       exclude_boundary_width=2):
    """Return a high-confidence mask with the same shape and dtype as ``cor_map``.

    Values in the returned mask:
        1  -- foreground: ``cor_map >= percentile(cor_map, high_percentile)``, inside a labeled
              region of ``get_label_image(cor_map, min_pixels)``, and ``cor_map > min_cor``.
        0  -- background: ``cor_map <= percentile(cor_map, low_percentile)``, outside every labeled
              region, and ``cor_map < min_cor``.
        -1 -- uncertain. This covers everything else, plus a border of width ``exclude_boundary_width``
              along the first two dimensions.

    Args:
        cor_map: tensor, indexed as at least 2-D (rows, cols).
        low_percentile, high_percentile: percentiles (0-100) passed to ``np.percentile``.
        min_cor: absolute correlation threshold separating foreground from background.
        min_pixels: forwarded to ``get_label_image``.
        exclude_boundary_width: border width forced to -1; values <= 0 disable the border.
    """
    label_image, _ = get_label_image(cor_map, min_pixels=min_pixels)
    labels = cor_map.new_tensor(label_image)
    # Detach so np.percentile also works on tensors that require grad.
    values = cor_map.detach().cpu().numpy()
    # NOTE(review): foreground uses ">= the high_percentile-th percentile"; with the default of 5
    # that keeps ~95% of pixels. If "top 5%" was intended, this should be 100 - high_percentile.
    # The existing behavior is kept here — confirm with callers.
    fg_threshold = np.percentile(values, high_percentile)
    bg_threshold = np.percentile(values, low_percentile)
    fg_mask = (cor_map >= fg_threshold) & (labels > 0) & (cor_map > min_cor)
    bg_mask = (cor_map <= bg_threshold) & (labels == 0) & (cor_map < min_cor)
    mask = cor_map.new_full(cor_map.size(), -1)
    mask[bg_mask] = 0
    mask[fg_mask] = 1
    width = exclude_boundary_width
    # Guard: with width == 0, mask[-0:] would select (and blank) the whole mask.
    if width > 0:
        mask[:width] = -1
        mask[-width:] = -1
        mask[:, :width] = -1
        mask[:, -width:] = -1
    return mask
def refine_segmentation(submat, label_mask, label_image, label_idx, plot=False, figsize=(15, 10)):
    """Try to split the region ``label_idx`` of ``label_image`` into sub-regions.

    ``svd`` is run on the pixels selected by ``label_mask``. If it reports at least two components,
    the following happens:
        1. ``submat`` is flattened per frame and decomposed by ``step_decompose`` into
           ``2 * num_pcs`` components.
        2. The first row of ``B`` is reshaped to the (rows, cols) grid.
        3. Its per-pixel values are clustered with KMeans into ``2 * num_pcs`` groups.
        4. The clustering is handed to ``get_label_image`` as the split of label ``label_idx - 1``.

    Args:
        submat: tensor indexed as (frames, rows, cols) — shape inferred from the reshapes below.
        label_mask: mask over (rows, cols) selecting the region's pixels.
        label_image: current label image to refine.
        label_idx: 1-based label to split (``get_label_image`` receives ``label_idx - 1``).
        plot, figsize: forwarded to ``svd`` / ``get_label_image``.

    Returns:
        (label_image, regions) from ``get_label_image`` when a split was attempted. When fewer than
        two components were found, returns the input ``label_image`` unchanged and ``regions=None``.
        (Previously the result was computed but never returned.)
    """
    num_pcs, _, _, _ = svd(submat[:, label_mask.bool()], plot=plot)
    if num_pcs < 2:
        return label_image, None
    _, B, _ = step_decompose(submat.reshape(submat.size(0), -1), num_components=2 * num_pcs)
    mask = B[0].reshape(submat.size(1), submat.size(2))
    # One feature per pixel: its value in the first decomposed component.
    features = mask.reshape(-1, 1).detach().cpu().numpy()
    kmeans = KMeans(n_clusters=2 * num_pcs, random_state=0).fit(features)
    split_label_segmentation = kmeans.labels_.reshape(tuple(mask.shape))
    label_image, regions = get_label_image(
        image=None, label_image=label_image, split_label=label_idx - 1,
        split_label_segmentation=split_label_segmentation, plot=plot, figsize=figsize)
    return label_image, regions
def _rfft_frames(x):
    """Orthonormal one-sided real FFT over the last dim, as (..., freq, 2) real/imag pairs.

    Matches the removed ``torch.rfft(x, signal_ndim=1, normalized=True)``. Uses ``torch.fft.rfft``
    when available (PyTorch >= 1.8) and otherwise falls back to the legacy API.
    """
    fft_module = getattr(torch, 'fft', None)
    if fft_module is not None and hasattr(fft_module, 'rfft'):
        return torch.view_as_real(fft_module.rfft(x, dim=-1, norm='ortho'))
    return torch.rfft(x, signal_ndim=1, normalized=True)


def basic_segmentation(mat, min_thresh=0.05, min_pixels=50, select_frames=True, show=True, median_detrend=False,
                       fft=False, fft_max_freq=200):
    """Basic segmentation

    Args:
        mat: torch.Tensor with shape (nframe, nrow, ncol) or (n_experiments, nframe, nrow, ncol)
        min_thresh: float, used by get_label_image
        min_pixels: int, used by get_label_image
        select_frames: default True, only used when mat.ndim==4, selecting only frames with average mean
            beyond an otsu threshold
        show: if True, display the correlation map and the label overlay
        median_detrend: if True, subtract get_local_median(mat, window_size=50) along the frame axis
        fft: if True, replace the frame axis by the first ``fft_max_freq`` orthonormal rFFT bins, with
            real and imaginary parts interleaved (so that axis gets 2 * n_kept_bins entries)
        fft_max_freq: number of frequency bins kept when ``fft`` is True

    Returns:
        cor_map: torch.Tensor with shape (nrow, ncol)
        label_image: torch.Tensor with shape (nrow, ncol); 0 is background; label==i is mask for label i
            for i >= 1
        regions: object returned by regionprops

    Raises:
        ValueError: if ``mat`` is neither 3-D nor 4-D.
    """
    if mat.ndim not in (3, 4):
        # Previously this surfaced as an UnboundLocalError on cor_map.
        raise ValueError(f'mat must be 3-D or 4-D, got shape {tuple(mat.shape)}')
    if mat.dtype == torch.float16:
        mat = mat.float()
    if median_detrend:
        mat = mat - get_local_median(mat, window_size=50, dim=-3)
    if fft:
        if mat.ndim == 3:
            nrow, ncol = mat.size(1), mat.size(2)
            # (ncol, nrow, nframe) -> (ncol, nrow, k, 2) -> (ncol, nrow, 2k) -> (2k, nrow, ncol)
            spectrum = _rfft_frames(mat.transpose(0, 2))[..., :fft_max_freq, :]
            mat = spectrum.reshape(ncol, nrow, -1).transpose(0, 2)
        else:
            n_exp, nrow, ncol = mat.size(0), mat.size(2), mat.size(3)
            spectrum = _rfft_frames(mat.transpose(1, 3))[..., :fft_max_freq, :]
            mat = spectrum.reshape(n_exp, ncol, nrow, -1).transpose(1, 3)
    if mat.ndim == 3:
        cor_map = get_cor_map(mat)
    else:
        cor_map = get_cor_map_4d(mat, select_frames=select_frames, top_cor_map_percentage=20, padding=2, topk=5,
                                 shift_times=[0, 1, 2], return_all=False, plot=False)
    label_image, regions = get_label_image(cor_map, min_thresh=min_thresh, min_pixels=min_pixels)
    label_image = torch.from_numpy(label_image).to(mat.device)
    if show:
        imshow(cor_map)
        plot_image_label_overlay(cor_map, label_image=label_image, regions=regions)
    return cor_map, label_image, regions