Code Example #1
def basic_segmentation(mat, min_thresh=0.05, min_pixels=50, select_frames=True, show=True, median_detrend=False, 
                       fft=False, fft_max_freq=200):
    """Basic segmentation 
    Args:
        mat: torch.Tensor with shape (nframe, nrow, ncol) or (n_experiments, nframe, nrow, ncol)
        min_thresh: float, used by get_label_image
        min_pixels: int, used by get_label_image
        select_frames: default True, only used when mat.ndim==4, selecting only frames with average mean beyond an otsu threshold
        
    Returns:
        cor_map: torch.Tensor with shape (nrow, ncol)
        label_image: torch.Tensor with shape (nrow, ncol); 0 is background; label==i is mask for label i for i >= 1
        regions: object returned by regionprops
        
    """
    dtype = mat.dtype
    if dtype == torch.float16:
        mat = mat.float()
    if median_detrend:
        mat = mat - get_local_median(mat, window_size=50, dim=-3)
    if fft:
        # keep the first fft_max_freq rFFT coefficients of each pixel's time series
        # (legacy torch.rfft API; replaced by torch.fft.rfft in PyTorch >= 1.8)
        if mat.ndim == 3:
            mat = torch.rfft(mat.transpose(0, 2), signal_ndim=1, normalized=True)[..., :fft_max_freq, :].reshape(
                mat.size(2), mat.size(1), -1).transpose(0, 2)
        elif mat.ndim == 4:
            mat = torch.rfft(mat.transpose(1, 3), signal_ndim=1, normalized=True)[..., :fft_max_freq, :].reshape(
                mat.size(0), mat.size(3), mat.size(2), -1).transpose(1, 3)
    if mat.ndim == 3:
        cor_map = get_cor_map(mat)
    elif mat.ndim == 4:
        cor_map = get_cor_map_4d(mat, select_frames=select_frames, top_cor_map_percentage=20, padding=2, topk=5, shift_times=[0, 1, 2], 
                                 return_all=False, plot=False)
    label_image, regions = get_label_image(cor_map, min_thresh=min_thresh, min_pixels=min_pixels)
    label_image = torch.from_numpy(label_image).to(mat.device)
    if show:
        imshow(cor_map)
        plot_image_label_overlay(cor_map, label_image=label_image, regions=regions)
    return cor_map, label_image, regions
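
# Hypothetical usage sketch (not part of the original source); it assumes the helpers
# used above (get_cor_map, get_label_image, imshow, plot_image_label_overlay) are
# importable from this project, and uses a synthetic movie as a stand-in for real data.
import torch

movie = torch.rand(1000, 64, 64)  # (nframe, nrow, ncol)
cor_map, label_image, regions = basic_segmentation(movie, min_thresh=0.05, min_pixels=50, show=False)
print(cor_map.shape)           # torch.Size([64, 64])
print(int(label_image.max()))  # number of detected segments
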
Code Example #2
def complete_fast_algorithm(content_path, style_path, img_size=512):
    """
    returns the instance of image which represents the result
    """
    import matplotlib.pyplot as plt
    from setup import load_img
    from visualization import imshow

    print('Loading Images ---')
    content_image = load_img(content_path, img_size)
    style_image = load_img(style_path, img_size)

    print('Executing style transfer, this could take a moment\nPlease wait...')
    result_image = fast_style_transfer(content_image, style_image)

    print('plotting images -->')
    plt.subplot(2, 2, 1)
    imshow(content_image, 'Content image')

    plt.subplot(2, 2, 2)
    imshow(style_image, 'Style image')

    return result_image
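
# Hypothetical usage sketch (not part of the original source); the image paths are
# placeholders, and fast_style_transfer plus the setup/visualization helpers are assumed
# to come from this project.
import matplotlib.pyplot as plt
from visualization import imshow  # project-local helper, assumed available

stylized = complete_fast_algorithm('content.jpg', 'style.jpg', img_size=512)
plt.subplot(2, 2, 3)
imshow(stylized, 'Stylized image')
plt.show()
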
Code Example #3
def detrend_high_magnification(mat, skip_segments=1, num_segments=6, period=500, train_size_left=0, train_size_right=350, 
                               linear_order=3, plot=False, signal_start=0, signal_end=100, filepath=None, size=(-1, 180, 300), 
                               device=torch.device('cuda'), start0=None, end0=None, return_mat=False, **kwargs):
    """Remove slow per-pixel trends by fitting an order-`linear_order` polynomial over baseline frames.
    Training frames are the first `skip_segments` periods plus `train_size_left`/`train_size_right`
    frames at the start/end of each remaining period; the fitted trend is subtracted from every frame.
    """
    if mat is None:
        mat = load_file(filepath=filepath, size=size, dtype=np.uint16, device=device)
    L, nrow, ncol = mat.size()
    if period == 'unknown':
        period = L
    if signal_end == 'period':
        signal_end = period
    train_idx = ([range(skip_segments*period)] +
                 [range(i*period, train_size_left+i*period) for i in range(skip_segments, num_segments)] + 
                 [range((i+1)*period - train_size_right, (i+1)*period) for i in range(skip_segments, num_segments)])
    train_idx = functools.reduce(lambda x,y: list(x)+list(y), train_idx)
    input_aug = torch.linspace(-2, 2, L, device=device)
    beta, trend = linear_regression(X=input_aug[train_idx], Y=mat.reshape(L, -1)[train_idx], order=linear_order, X_test=input_aug)
    mat_adj = mat - trend.reshape(L, nrow, ncol)
    if plot:
        frame_mean = mat.mean(-1).mean(-1)
        plt.figure(figsize=(20, 10))
        plt.plot(frame_mean.cpu(), 'o--', linewidth=1, markersize=2)
        for i in range(num_segments):
            plt.axvline(i*period+signal_start, color='g', linestyle='-.')
            plt.axvline(i*period+signal_end, color='g', linestyle='-.')
        plt.title('Frame mean intensity')
        plt.show()
        imshow(mat.mean(0), title='Mean intensity')
        imshow(trend.mean(0).reshape(nrow, ncol), title='Trend')
        imshow(mat_adj.mean(0).reshape(nrow, ncol), title='Detrended')

        cor = neighbor_cor(mat, neighbors=8, plot=True, choice='max', title='cor mat')
        plot_hist(cor, show=True)
        cor = neighbor_cor(trend.reshape(-1, nrow, ncol), neighbors=8, plot=True, choice='max', title='cor trend')
        plot_hist(cor, show=True)
        cor = neighbor_cor(mat_adj.reshape(-1, nrow, ncol), neighbors=8, plot=True, choice='max', title='cor mat_adj')
        plot_hist(cor, show=True)
    
        plot_mean_intensity(mat, detrended=mat_adj, plot_detrended=True, plot_segments=True, num_frames=L, period=period, 
                            signal_length=signal_end-signal_start)
    if start0 is not None and end0 is not None:
        mat_adj = [mat_adj[s:e] for s, e in zip(start0, end0)]
    if return_mat:
        return mat, mat_adj
    else:
        return mat_adj
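
# Minimal self-contained sketch (not from the original source) of the detrending idea
# above: fit a low-order polynomial over baseline frames taken from the tail of each
# period, then subtract the fitted trend. Uses only NumPy, with np.polyfit standing in
# for the project's linear_regression helper and a single-pixel trace instead of a movie.
import numpy as np

L, period, train_size_right, order = 3000, 500, 350, 3
t = np.linspace(-2, 2, L)
trace = 0.3 * t ** 2 - 0.1 * t + 0.01 * np.random.randn(L)  # slow trend + noise

# train only on the last train_size_right frames of each period (assumed signal-free)
train_idx = np.concatenate([np.arange((i + 1) * period - train_size_right, (i + 1) * period)
                            for i in range(L // period)])
trend = np.polyval(np.polyfit(t[train_idx], trace[train_idx], deg=order), t)
detrended = trace - trend
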
Code Example #4
def entire_pipeline(bucket, result_folder='results', bin_files=None, delete_local_data=True, 
                    apply_spectral_clustering=False, spectral_soft_threshold=True, spectral_cor_threshold=None,
                    denoise=False, denoise_model_config=None, denoise_loss_threshold=0, denoise_num_epochs=12, denoise_num_iters=600, 
                    denoise_batch_size=2, denoise_batch_size_eval=4,
                    display=False, verbose=False, half_precision=False, device=torch.device('cuda')):
    """Entire pipeline to process OPP voltage imaging data
    Args:
        bucket: google bucket folder containing .bin files and metadata .json files
        result_folder: default 'results'
        bin_files: default None, in which case all .bin files with metadata in the bucket are processed; 
            otherwise only the specified files are processed; bin_files can be a list of file names or a single file name
            (note: the .bin suffix should not be included in the file names)
        delete_local_data: if True, will delete all intermediate data on the local disk
        apply_spectral_clustering: if True, apply spectral clustering to get fine-grained segmentation
        spectral_soft_threshold: if True, automatically determine when to split and when to stop
        spectral_cor_threshold: hard threshold between 0 and 1 used to determine when to split a cluster; 
            only used when spectral_soft_threshold is False
        denoise: if True, run the denoising pipeline; note that running it on multiple .bin files currently often triggers GPU memory errors
        denoise_model_config: default None; if not None, provide a json file path
        denoise_loss_threshold: default 0, if bigger than 0, then training will automatically stop after loss is below this threshold
        denoise_num_epochs: num of epochs to train the denoising model
        denoise_num_iters: number of iterations per epoch
        denoise_batch_size: default 2, can be increased to a larger integer if GPU memory is sufficient
        denoise_batch_size_eval: default 4
    
    Returns:
        None; after the call, the results are uploaded to the Google bucket automatically
    """
    def save_results(mat, softmask, label_image, regions, save_folder, file, segmentation_name='basic_segmentation', display=False, verbose=False):
        submats, traces = extract_traces(mat, softmask=softmask, label_image=label_image, regions=regions, median_detrend=False)
        plot_image_label_overlay(softmask, label_image=label_image, regions=regions, save_file=f'{save_folder}/label_image__{segmentation_name}.png', 
                                 title=f'{file}: {label_image.max()} neurons detected', display=display)
        if len(traces) > 0:
            if verbose:
                print(f'{len(traces)} neurons detected in {file}.bin')
            np.save(f'{save_folder}/label_image__{segmentation_name}.npy', label_image.cpu().numpy())
            np.save(f'{save_folder}/traces__{segmentation_name}.npy', traces.cpu().numpy())
    
    def save_figures(segmentation_name='basic_segmentation', figsize=(20, 20), bounding_box=True):
        for file in bin_files:
            save_folder = f'{data_folder}/{result_folder}/{file}'
            if os.path.exists(f'{save_folder}/label_image__{segmentation_name}.npy'):
                cor_map = np.load(f'{save_folder}/cor_map.npy')
                label_image = np.load(f'{save_folder}/label_image__{segmentation_name}.npy')
                traces = np.load(f'{save_folder}/traces__{segmentation_name}.npy')
                image_label_overlay = label2rgb(label_image, image=cor_map)
                regions = regionprops(label_image)
                fig_folder = f'{save_folder}/figs/{segmentation_name}'
                if not os.path.exists(fig_folder):
                    os.makedirs(fig_folder)
                for sel_idx in range(label_image.max()):
                    fig, ax = plt.subplots(2, figsize=figsize)
                    ax[0].imshow(image_label_overlay)
                    if bounding_box:
                        region = regions[sel_idx]
                        minr, minc, maxr, maxc = region.bbox
                        rect = mpatches.Rectangle((minc, minr), maxc - minc, maxr - minr,
                                                  fill=False, edgecolor='red', linewidth=2)
                        ax[0].add_patch(rect)
                        ax[0].text(minc-3, minr-1, sel_idx+1, color='r')
                    ax[0].set_axis_off()
                    ax[0].set_title(f'Neuron {sel_idx+1} segmentation')
                    ax[1].plot(traces[sel_idx])
                    ax[1].set_title(f'Neuron {sel_idx+1} trace')
                    plt.tight_layout()
                    plt.savefig(f'{fig_folder}/{sel_idx+1}.png')
                    plt.close()
                imgs = [f'{fig_folder}/{i+1}.png' for i in range(label_image.max())]
                save_gif_file(imgs, save_path=f'{fig_folder}/{label_image.max()}_neurons.gif')
                fig, ax = plt.subplots(figsize=figsize)
                for i in range(len(traces)):
                    ax.plot(traces[i], label=i+1, c=good_colors[i%len(good_colors)])
                    ax.text(-len(traces[i])*0.02, traces[i, :10].mean(), i+1, c=good_colors[i%len(good_colors)])
                plt.legend()
                plt.savefig(f'{fig_folder}/{label_image.max()}_traces.png')
                plt.close()
    
    bin_files, meta_data, data_folder = prepare_data(bucket=bucket, bin_files=bin_files, result_folder=result_folder, 
                                                     data_folder_prefix='.', verbose=verbose)
    
    good_colors = get_good_colors()
    
    if denoise:
        if denoise_model_config is not None and os.path.exists(denoise_model_config):
            with open(denoise_model_config, 'r') as f:
                denoise_model_config = json.load(f)
        else:
            denoise_model_config = {}
        denoise_features = denoise_model_config.pop('features', True)
        denoise_optimizer_fn_args = denoise_model_config.pop('optimizer_fn_args', {'lr': 1e-3, 'weight_decay': 1e-2})
        denoise_lr_scheduler = denoise_model_config.pop('denoise_lr_scheduler', None)
        # TODO: put all arguments into denoise_model_config
        denoise_model = None # will be updated for each file
        
    for file in bin_files:
        print(f'Process {file}')
        start_time = time.time()
        # download .bin file if not exists
        if not os.path.exists(f'{data_folder}/{file}.bin'):
            command = ['gsutil', '-m', 'cp', f'{bucket}/{file}.bin', data_folder]
            response = subprocess.run(command, capture_output=True)
            assert response.returncode == 0
        # create save folder if not exists
        save_folder = f'{data_folder}/{result_folder}/{file}'
        if not os.path.exists(save_folder):
            if verbose:
                print(f'Create folder {save_folder}')
            os.makedirs(save_folder)
        # load mat and meta data
        nframe = meta_data[file]['numFramesRequested']
        ncol, nrow = meta_data[file]['movSize']
        blue_light_on_off = np.array(meta_data[file]['blueFrameOnOff']).reshape(-1, 2)
        torch.cuda.empty_cache()
        mat = load_file(f'{data_folder}/{file}.bin', size=(nframe, nrow, ncol), astype='float16' if half_precision else 'float32', device=device)
        # without detrending, only use selected frames to calculate correlation
        submat = torch.stack([mat[i-1:j-1] for i, j in blue_light_on_off], dim=0)
        cor_map, label_image, regions = basic_segmentation(submat, min_pixels=20, select_frames=True, median_detrend=False, fft=False, show=False)
        # save correlation map
        imshow(cor_map, save_file=f'{save_folder}/cor_map.png', title=f'{file}: min_cor={cor_map.min():.2f}, max_cor={cor_map.max():.2f}', 
               display=display)
        np.save(f'{save_folder}/cor_map.npy', cor_map.cpu().numpy())
        # save basic segmentation results
        save_results(mat, segmentation_name='basic_segmentation', softmask=cor_map, label_image=label_image, regions=regions, save_folder=save_folder,
                     file=file, display=display, verbose=verbose)
        if denoise:
            # noise2self
            if verbose:
                print('Denoising')
                now = time.time()
            denoise_loss_history = []
            denoise_save_folder = denoise_model_config.pop('denoise_save_folder', f'{save_folder}/denoise')
            denoised_mat, denoise_model = get_denoised_mat(mat, features=denoise_features, model=denoise_model, save_folder=denoise_save_folder, 
                                                           loss_threshold=denoise_loss_threshold, loss_history=denoise_loss_history, verbose=verbose, 
                                                           optimizer_fn_args=denoise_optimizer_fn_args, lr_scheduler=denoise_lr_scheduler, 
                                                           out_channels=[64, 64, 128], kernel_size_unet=3, ndim=2, frame_depth=4,
                                                           last_out_channels=100, normalize=True, 
                                                           num_epochs=denoise_num_epochs, num_iters=denoise_num_iters, print_every=300, 
                                                           batch_size=denoise_batch_size, batch_size_eval=denoise_batch_size_eval, 
                                                           mask_prob=0.05, frame_weight=None, 
                                                           save_intermediate_results=False, movie_start_idx=250, movie_end_idx=750, fps=60,
                                                           loss_reg_fn=nn.MSELoss(), optimizer_fn=torch.optim.AdamW, 
                                                           window_size_row=None, window_size_col=None, weight=None, return_model=True, device=device)
            save_results(denoised_mat, segmentation_name='basic_segmentation_denoise', softmask=cor_map, label_image=label_image, regions=regions, 
                         save_folder=save_folder, file=file, display=display, verbose=verbose)
            if verbose:
                print(f'Denoising time: {time.time() - now:.2f} s')
        if apply_spectral_clustering:
            if verbose:
                print('Apply spectral clustering')
                now = time.time()
            sel_label_idx = 1
            while sel_label_idx <= label_image.max():
                if verbose:
                    print(sel_label_idx)
                split_clusters(sel_label_idx, mat, label_image, cor_threshold=spectral_cor_threshold, soft_threshold=spectral_soft_threshold, 
                               min_num_pixels=50, max_dist=2, median_detrend=True, apply_fft=True, fft_max_freq=200, verbose=verbose)
                sel_label_idx += 1
            regions = regionprops(label_image.cpu().numpy())
            save_results(mat, segmentation_name='spectral_clustering', softmask=cor_map, label_image=label_image, regions=regions, 
                         save_folder=save_folder, file=file, display=display, verbose=verbose)
            if verbose:
                print(f'Spectral clustering time: {time.time() - now:.2f} s')

        if verbose:
            print('Generating figures and uploading results to google bucket')
        save_figures(segmentation_name='basic_segmentation')
        save_figures(segmentation_name='basic_segmentation_denoise')
        save_figures(segmentation_name='spectral_clustering')
        save_figures(segmentation_name='denoise_spectral_clustering')
        command = ['gsutil', '-m', 'cp', '-r', 
                   save_folder, 
                   f'{bucket}/{result_folder}/{file}']
        response = subprocess.run(command, capture_output=True)
        assert response.returncode == 0
    
        if delete_local_data:
            os.remove(f'{data_folder}/{file}.bin')
            shutil.rmtree(f'{data_folder}/{result_folder}/{file}')
        empty_cache(lambda k, v: isinstance(v, torch.Tensor))
        torch.cuda.empty_cache()
        end_time = time.time()
        print(f'Time spent: {end_time - start_time}')
        
    if delete_local_data:
        shutil.rmtree(data_folder)
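
# Hypothetical invocation (not from the original source); the bucket path is a
# placeholder, and gsutil credentials plus the project's helper modules (and torch,
# imported as above) are assumed to be set up on the machine running the pipeline.
entire_pipeline(bucket='gs://my-voltage-imaging-bucket/session01',
                result_folder='results',
                bin_files=None,                  # process every .bin file that has metadata
                apply_spectral_clustering=True,  # refine the basic segmentation
                denoise=False,                   # skip noise2self to limit GPU memory use
                verbose=True,
                device=torch.device('cuda'))
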
Code Example #5
def neighbor_cor(mat,
                 neighbors=8,
                 plot=True,
                 choice='max',
                 title='Correlation Map',
                 return_adj_list=False):
    """Assume mat has shape (D, nrow, ncol)
    """
    is_grad_enabled = torch.is_grad_enabled()
    torch.set_grad_enabled(False)
    cor1 = cosine_similarity(mat[:, 1:], mat[:, :-1], dim=0)  # row shift
    cor2 = cosine_similarity(mat[:, :, 1:], mat[:, :, :-1],
                             dim=0)  # column shift
    cor3 = cosine_similarity(mat[:, 1:, 1:], mat[:, :-1, :-1],
                             dim=0)  # diagonal 135
    cor4 = cosine_similarity(mat[:, 1:, :-1], mat[:, :-1, 1:],
                             dim=0)  # diagonal 45
    nrow, ncol = mat.shape[1:]
    if return_adj_list:
        adj_list = np.concatenate([
            get_adj_list(c, nrow, ncol, i + 1)
            for i, c in enumerate([cor1, cor2, cor3, cor4])
        ],
                                  axis=0)
        return adj_list

    cor = mat.new_zeros(mat.shape[1:])
    if choice == 'mean':
        cor[:-1] += cor1
        cor[1:] += cor1
        cor[:, :-1] += cor2
        cor[:, 1:] += cor2
        if neighbors == 4:
            denominators = [4, 3, 2]
        elif neighbors == 8:
            denominators = [8, 5, 3]
            cor[1:, 1:] += cor3
            cor[:-1, :-1] += cor3
            cor[1:, :-1] += cor4
            cor[:-1, 1:] += cor4
        else:
            raise ValueError(f'neighbors={neighbors} is not implemented!')
        cor[1:-1, 1:-1] /= denominators[0]
        cor[0, 1:-1] /= denominators[1]
        cor[-1, 1:-1] /= denominators[1]
        cor[1:-1, 0] /= denominators[1]
        cor[1:-1, -1] /= denominators[1]
        cor[0, 0] /= denominators[2]
        cor[0, -1] /= denominators[2]
        cor[-1, 0] /= denominators[2]
        cor[-1, -1] /= denominators[2]
    elif choice == 'max':
        cor[:-1] = torch.max(cor[:-1], cor1)
        cor[1:] = torch.max(cor[1:], cor1)
        cor[:, :-1] = torch.max(cor[:, :-1], cor2)
        cor[:, 1:] = torch.max(cor[:, 1:], cor2)
        if neighbors == 8:
            cor[1:, 1:] = torch.max(cor[1:, 1:], cor3)
            cor[:-1, :-1] = torch.max(cor[:-1, :-1], cor3)
            cor[1:, :-1] = torch.max(cor[1:, :-1], cor4)
            cor[:-1, 1:] = torch.max(cor[:-1, 1:], cor4)
    else:
        raise ValueError(f'choice = {choice} is not implemented!')
    if plot:
        #         imshow(cor1, title='cor1')
        #         imshow(cor2, title='cor2')
        #         imshow(cor3, title='cor3')
        #         imshow(cor4, title='cor4')
        imshow(mat.mean(0), title='Temporal Mean')
        imshow(cor, title=title)
    torch.set_grad_enabled(is_grad_enabled)
    # note: ``del locals()[k]`` does not remove local variables inside a function,
    # so delete the intermediate tensors explicitly before emptying the CUDA cache
    del mat, cor1, cor2, cor3, cor4
    torch.cuda.empty_cache()
    return torch.nn.functional.relu(cor, inplace=True)
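
# Minimal sketch (not from the original source) of the shifted cosine-similarity trick
# used above: slicing the (D, nrow, ncol) stack by one pixel and comparing along dim=0
# correlates every pixel's D-sample vector with a neighboring pixel's vector.
import torch
import torch.nn.functional as F

stack = torch.rand(200, 32, 32)                                           # (D, nrow, ncol)
row_cor = F.cosine_similarity(stack[:, 1:], stack[:, :-1], dim=0)         # (31, 32): vertical neighbors
col_cor = F.cosine_similarity(stack[:, :, 1:], stack[:, :, :-1], dim=0)   # (32, 31): horizontal neighbors

cor = neighbor_cor(stack, neighbors=8, choice='max', plot=False)          # full 8-neighbor map
print(cor.shape)                                                          # torch.Size([32, 32])
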
Code Example #6
File: utility.py  Project: BeautyOfWeb/OPP_Analysis
def neighbor_cor(mat,
                 neighbors=8,
                 choice='mean',
                 nonnegative=True,
                 return_adj_list=False,
                 plot=False,
                 title='Correlation Map'):
    """Calculate neighborhood correlation map; deprecated, in favor of get_cor
    Args:
        mat: 3D torch.Tensor with shape (nframe, nrow, ncol)
    
    Returns:
        cor: 2D torch.Tensor, correlation map
    """
    cor1 = cosine_similarity(mat[:, 1:], mat[:, :-1], dim=0)  # row shift
    cor2 = cosine_similarity(mat[:, :, 1:], mat[:, :, :-1],
                             dim=0)  # column shift
    cor3 = cosine_similarity(mat[:, 1:, 1:], mat[:, :-1, :-1],
                             dim=0)  # diagonal 135
    cor4 = cosine_similarity(mat[:, 1:, :-1], mat[:, :-1, 1:],
                             dim=0)  # diagonal 45
    nrow, ncol = mat.shape[1:]
    if return_adj_list:
        adj_list = np.concatenate([
            get_adj_list(c, nrow, ncol, i + 1)
            for i, c in enumerate([cor1, cor2, cor3, cor4])
        ],
                                  axis=0)
        return adj_list
    with torch.no_grad():
        cor = mat.new_zeros(mat.shape[1:])
        if choice == 'mean':
            cor[:-1] += cor1
            cor[1:] += cor1
            cor[:, :-1] += cor2
            cor[:, 1:] += cor2
            if neighbors == 4:
                denominators = [4, 3, 2]
            elif neighbors == 8:
                denominators = [8, 5, 3]
                cor[1:, 1:] += cor3
                cor[:-1, :-1] += cor3
                cor[1:, :-1] += cor4
                cor[:-1, 1:] += cor4
            else:
                raise ValueError(f'neighbors={neighbors} is not implemented!')
            cor[1:-1, 1:-1] /= denominators[0]
            cor[0, 1:-1] /= denominators[1]
            cor[-1, 1:-1] /= denominators[1]
            cor[1:-1, 0] /= denominators[1]
            cor[1:-1, -1] /= denominators[1]
            cor[0, 0] /= denominators[2]
            cor[0, -1] /= denominators[2]
            cor[-1, 0] /= denominators[2]
            cor[-1, -1] /= denominators[2]
        elif choice == 'max':
            cor[:-1] = torch.max(cor[:-1], cor1)
            cor[1:] = torch.max(cor[1:], cor1)
            cor[:, :-1] = torch.max(cor[:, :-1], cor2)
            cor[:, 1:] = torch.max(cor[:, 1:], cor2)
            if neighbors == 8:
                cor[1:, 1:] = torch.max(cor[1:, 1:], cor3)
                cor[:-1, :-1] = torch.max(cor[:-1, :-1], cor3)
                cor[1:, :-1] = torch.max(cor[1:, :-1], cor4)
                cor[:-1, 1:] = torch.max(cor[:-1, 1:], cor4)
        else:
            raise ValueError(f'choice = {choice} is not implemented!')
    if plot:
        imshow(mat.mean(0), title='Temporal Mean')
        imshow(cor, title=title)
    if nonnegative:
        cor = torch.nn.functional.relu(cor, inplace=True)
    # note: ``del locals()[k]`` does not remove local variables inside a function,
    # so delete the large intermediate tensors explicitly before emptying the CUDA cache
    del mat, cor1, cor2, cor3, cor4
    torch.cuda.empty_cache()
    return cor
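
# Quick sanity check (not from the original source) of the 'mean' normalization above:
# for a constant movie every neighboring pair has cosine similarity 1, and dividing by
# the per-pixel neighbor count (8 interior, 5 edge, 3 corner) should give a map of ones.
import torch

flat_movie = torch.ones(50, 16, 16)
cor = neighbor_cor(flat_movie, neighbors=8, choice='mean', plot=False)
print(torch.allclose(cor, torch.ones_like(cor)))  # expected: True
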