Example #1
def mc_vids(vids_fpath, mc_rigid_template):
    start = time.time()
    # estimated minimum value of the movie to produce an output that is positive
    min_mov = np.array([
        cm.motion_correction.high_pass_filter_space(m_, gSig_filt)
        for m_ in cm.load(vids_fpath[0], subindices=range(400))
    ]).min()
    mc = MotionCorrect(vids_fpath,
                       min_mov,
                       dview=dview,
                       max_shifts=max_shifts,
                       niter_rig=1,
                       splits_rig=splits_rig,
                       num_splits_to_process_rig=None,
                       shifts_opencv=True,
                       nonneg_movie=True,
                       gSig_filt=gSig_filt,
                       border_nan=border_nan,
                       is3D=False)

    mc.motion_correct_rigid(save_movie=(not doPwRigid),
                            template=mc_rigid_template)

    shifts_rig = mc.shifts_rig
    template_rig = mc.total_template_rig

    if doPwRigid:
        mc.motion_correct_pwrigid(save_movie=True, template=template_rig)
        mc.total_template_rig = template_rig

    duration = time.time() - start
    logging.info('Motion correction done in %s', str(duration))
    return mc, duration, shifts_rig
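Example #1 relies on several module-level names (dview, max_shifts, splits_rig, gSig_filt, border_nan, doPwRigid) that are defined elsewhere in its module. A minimal sketch of how they might be set up before calling mc_vids; the values and the file name are illustrative assumptions borrowed from the other examples on this page:

# Hypothetical setup for mc_vids(); values and 'example_movie.tif' are placeholders.
import logging
import time
import numpy as np
import caiman as cm
from caiman.motion_correction import MotionCorrect

doPwRigid = True        # switch between rigid-only and pw-rigid correction
max_shifts = (6, 6)     # maximum allowed rigid shift
splits_rig = 56         # split the movie into this many chunks for parallelization
gSig_filt = (3, 3)      # high-pass filter size (used for 1p data)
border_nan = 'copy'     # how border pixels are handled

c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local', n_processes=None, single_thread=False)

mc, duration, shifts_rig = mc_vids(['example_movie.tif'], mc_rigid_template=None)
cm.stop_server(dview=dview)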
Example #2
    def _run_motion_correction(self, file_name, max_shifts, strides, overlaps,
                               upsample_factor_grid, max_deviation_rigid):
        """
        Private function that initiates motion correction from the CaImAn package
        and returns the motion-corrected file names and their respective pixel
        shifts.
        """

        offset_orig = np.nanmin(self.data_orig[:1000])
        G6Image = MotionCorrect(file_name,
                                offset_orig,
                                max_shifts=max_shifts,
                                niter_rig=1,
                                splits_rig=56,
                                strides=strides,
                                overlaps=overlaps,
                                splits_els=56,
                                upsample_factor_grid=upsample_factor_grid,
                                shifts_opencv=True,
                                max_deviation_rigid=max_deviation_rigid,
                                nonneg_movie=True)

        G6Image.motion_correct_rigid(save_movie=True)

        G6Image.motion_correct_pwrigid(save_movie=True,
                                       template=G6Image.total_template_rig)
        name_rig = G6Image.fname_tot_rig
        name_pwrig = G6Image.fname_tot_els
        shifts_rig = G6Image.shifts_rig
        x_shifts_pwrig = G6Image.x_shifts_els
        y_shifts_pwrig = G6Image.y_shifts_els
        template_shape = G6Image.total_template_els.shape

        return name_rig, name_pwrig, shifts_rig, x_shifts_pwrig, y_shifts_pwrig, template_shape
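Downstream, the returned shifts are typically converted into a border width and the corrected file is memory-mapped, as the later examples on this page do. A minimal sketch, assuming the tuple returned above has been unpacked into the same names:

# Sketch only: the names below are the return values of _run_motion_correction.
import numpy as np
import caiman as cm

bord_px = int(np.ceil(np.maximum(np.max(np.abs(np.array(x_shifts_pwrig))),
                                 np.max(np.abs(np.array(y_shifts_pwrig))))))
fname_new = cm.save_memmap(name_pwrig, base_name='memmap_', order='C',
                           border_to_0=bord_px)  # trim motion borders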
Example #3
    def motion_correct_pwrigid(self, fname):
        dview = None
        try:
            c, dview, n_processes = cm.cluster.setup_cluster(
                backend='local', n_processes=None, single_thread=False)

            niter_rig = 1  # number of iterations for rigid motion correction
            max_shifts = self.get_dict_param(
                'max_shifts_pwrigid', 'tuple_int')  # maximum allowed rigid shift
            # for parallelization, split the movie into num_splits chunks across time
            splits_rig = self.get_dict_param('splits_rig', 'single_int')
            # start a new patch for pw-rigid motion correction every x pixels
            strides = self.get_dict_param('strides', 'tuple_int')
            # overlap between patches (patch size = strides + overlaps)
            overlaps = self.get_dict_param('overlaps', 'tuple_int')
            # for parallelization, split the movie into num_splits chunks across time
            splits_els = self.get_dict_param('splits_els', 'single_int')

            upsample_factor_grid = self.get_dict_param(
                'upsample_factor_grid', 'single_int'
            )  # upsample factor to avoid smearing when merging patches
            # maximum deviation allowed for patch with respect to rigid shifts
            max_deviation_rigid = self.get_dict_param('max_deviation_rigid',
                                                      'single_int')
            # first we create a motion correction object with the parameters specified
            min_mov = cm.load(fname[0], subindices=range(200)).min()
            # this will be subtracted from the movie to make it non-negative

            print(
                str([
                    max_shifts, splits_rig, strides, overlaps, splits_els,
                    upsample_factor_grid, max_deviation_rigid, min_mov
                ]))
            mc = MotionCorrect(fname,
                               min_mov,
                               dview=dview,
                               max_shifts=max_shifts,
                               niter_rig=niter_rig,
                               splits_rig=splits_rig,
                               strides=strides,
                               overlaps=overlaps,
                               splits_els=splits_els,
                               border_nan='copy',
                               upsample_factor_grid=upsample_factor_grid,
                               max_deviation_rigid=max_deviation_rigid,
                               shifts_opencv=True,
                               nonneg_movie=True)

            mc.motion_correct_pwrigid(save_movie=True)
            self.motion_correct = mc

        except Exception as e:
            raise e
        finally:
            cm.cluster.stop_server(dview=dview)
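Example #3 reads its settings through self.get_dict_param; the backing dictionary is not shown. A plausible layout, using only the keys requested above (the values are illustrative assumptions taken from the other examples):

# Assumed parameter dictionary consumed by get_dict_param(); values are examples only.
motion_params = {
    'max_shifts_pwrigid': (6, 6),   # maximum allowed rigid shift
    'splits_rig': 56,               # chunks across time for parallelization
    'strides': (48, 48),            # new pw-rigid patch every x pixels
    'overlaps': (24, 24),           # overlap between patches
    'splits_els': 56,               # chunks across time for parallelization
    'upsample_factor_grid': 4,      # upsample factor when merging patches
    'max_deviation_rigid': 3,       # max deviation of a patch from the rigid shift
}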
Example #4
def motion_correct(video_path, max_shift, patch_stride, patch_overlap, use_multiprocessing=True):
    full_video_path = video_path

    directory = os.path.dirname(full_video_path)
    filename  = os.path.basename(full_video_path)

    memmap_video = tifffile.memmap(video_path)

    if use_multiprocessing:
        if os.name == 'nt':
            backend = 'multiprocessing'
        else:
            backend = 'ipyparallel'

        # Create the cluster
        cm.stop_server()
        c, dview, n_processes = cm.cluster.setup_cluster(backend=backend, n_processes=None, single_thread=False)
    else:
        dview = None

    z_range = list(range(memmap_video.shape[1]))

    new_video_path = os.path.join(directory, "mc_video_temp.tif")

    shutil.copyfile(video_path, new_video_path)

    mc_video = tifffile.memmap(new_video_path).astype(np.uint16)

    mc_borders = [ None for z in z_range ]

    counter = 0

    for z in z_range:
        print("Motion correcting plane z={}...".format(z))
        video_path = os.path.join(directory, os.path.splitext(filename)[0] + "_z_{}_temp.tif".format(z))
        tifffile.imsave(video_path, memmap_video[:, z, :, :])

        mc_video[:, z, :, :] *= 0

        # --- PARAMETERS --- #

        params_movie = {'fname': video_path,
                        'max_shifts': (max_shift, max_shift),  # maximum allowed rigid shift
                        'niter_rig': 3,
                        'splits_rig': 1,  # for parallelization, split the movie into num_splits chunks across time
                        'num_splits_to_process_rig': None,  # if None, all splits are processed and the movie is saved
                        'strides': (patch_stride, patch_stride),  # intervals at which patches are laid out for motion correction
                        'overlaps': (patch_overlap, patch_overlap),  # overlap between patches (patch size = strides + overlaps)
                        'splits_els': 1,  # for parallelization, split the movie into num_splits chunks across time
                        'num_splits_to_process_els': [None],  # if None, all splits are processed and the movie is saved
                        'upsample_factor_grid': 4,  # upsample factor to avoid smearing when merging patches
                        'max_deviation_rigid': 3,  # maximum deviation allowed for patch with respect to rigid shift
                        }

        # load movie (in memory!)
        fname = params_movie['fname']
        niter_rig = params_movie['niter_rig']
        # maximum allowed rigid shift
        max_shifts = params_movie['max_shifts']
        # for parallelization, split the movie into num_splits chunks across time
        splits_rig = params_movie['splits_rig']
        # if None, all splits are processed and the movie is saved
        num_splits_to_process_rig = params_movie['num_splits_to_process_rig']
        # intervals at which patches are laid out for motion correction
        strides = params_movie['strides']
        # overlap between patches (patch size = strides + overlaps)
        overlaps = params_movie['overlaps']
        # for parallelization, split the movie into num_splits chunks across time
        splits_els = params_movie['splits_els']
        # if None, all splits are processed and the movie is saved
        num_splits_to_process_els = params_movie['num_splits_to_process_els']
        # upsample factor to avoid smearing when merging patches
        upsample_factor_grid = params_movie['upsample_factor_grid']
        # maximum deviation allowed for patch with respect to rigid shift
        max_deviation_rigid = params_movie['max_deviation_rigid']

        # --- RIGID MOTION CORRECTION --- #

        # Load the original movie
        m_orig = tifffile.memmap(fname)
        # m_orig = cm.load(fname)
        min_mov = np.min(m_orig) # movie must be mostly positive for this to work

        offset_mov = -min_mov

        # Create motion correction object
        mc = MotionCorrect(fname, min_mov,
                           dview=dview, max_shifts=max_shifts, niter_rig=niter_rig,
                           splits_rig=splits_rig,
                           num_splits_to_process_rig=num_splits_to_process_rig,
                           strides=strides, overlaps=overlaps, splits_els=splits_els,
                           num_splits_to_process_els=num_splits_to_process_els,
                           upsample_factor_grid=upsample_factor_grid,
                           max_deviation_rigid=max_deviation_rigid,
                           shifts_opencv=True, nonneg_movie=True, border_nan='min')

        # Do rigid motion correction
        mc.motion_correct_rigid(save_movie=False)

        # --- ELASTIC MOTION CORRECTION --- #

        # Do elastic motion correction
        mc.motion_correct_pwrigid(save_movie=True, template=mc.total_template_rig, show_template=False)

        # # Save elastic shift border
        bord_px_els = np.ceil(np.maximum(np.max(np.abs(mc.x_shifts_els)),
                                         np.max(np.abs(mc.y_shifts_els)))).astype(int)
        # np.savez(mc.fname_tot_els + "_bord_px_els.npz", bord_px_els)

        fnames = mc.fname_tot_els   # name of the pw-rigidly corrected file.
        border_to_0 = bord_px_els     # number of pixels to exclude
        fname_new = cm.save_memmap(fnames, base_name='memmap_z_{}'.format(z), order='C',
                                   border_to_0=bord_px_els)  # exclude borders

        # now load the file
        Yr, dims, T = cm.load_memmap(fname_new)
        d1, d2 = dims
        images = np.reshape(Yr.T, [T] + list(dims), order='F') 

        mc_borders[z] = bord_px_els

        # images += np.amin(images)

        # print(np.amax(images))
        # print(np.amin(images))
        # print(type(images))

        mc_video[:, z, :, :] = (images - np.amin(images)).astype(np.uint16)

        del m_orig
        os.remove(video_path)

        try:
            os.remove(mc.fname_tot_rig)
            os.remove(mc.fname_tot_els)
        except:
            pass

        counter += 1

    if use_multiprocessing:
        if backend == 'multiprocessing':
            dview.close()
        else:
            try:
                dview.terminate()
            except:
                dview.shutdown()
        cm.stop_server()

    mmap_files = glob.glob(os.path.join(directory, '*.mmap'))
    for mmap_file in mmap_files:
        try:
            os.remove(mmap_file)
        except:
            pass

    log_files = glob.glob('Yr*_LOG_*')
    for log_file in log_files:
        os.remove(log_file)

    return mc_video, new_video_path, mc_borders
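A hedged usage sketch for the motion_correct function above; the file name is a placeholder and the parameter values are illustrative, drawn from the other examples on this page:

# Illustrative call only; 'volume.tif' stands in for a (T, Z, X, Y) TIFF stack.
mc_video, mc_video_path, mc_borders = motion_correct(
    'volume.tif',
    max_shift=6,        # becomes max_shifts=(6, 6)
    patch_stride=48,    # becomes strides=(48, 48)
    patch_overlap=24,   # becomes overlaps=(24, 24)
    use_multiprocessing=True)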
Example #5
                   min_mov,
                   dview=dview,
                   max_shifts=max_shifts,
                   niter_rig=niter_rig,
                   splits_rig=splits_rig,
                   strides=strides,
                   overlaps=overlaps,
                   splits_els=splits_els,
                   upsample_factor_grid=upsample_factor_grid,
                   max_deviation_rigid=max_deviation_rigid,
                   shifts_opencv=True,
                   nonneg_movie=True)
# note that the file is not loaded in memory

#%% Run piecewise-rigid motion correction using NoRMCorre
mc.motion_correct_pwrigid(save_movie=True)
m_els = cm.load(mc.fname_tot_els)
bord_px_els = np.ceil(
    np.maximum(np.max(np.abs(mc.x_shifts_els)),
               np.max(np.abs(mc.y_shifts_els)))).astype(int)
# maximum shift to be used for trimming against NaNs
#%% compare with original movie
cm.concatenate([
    m_orig.resize(1, 1, downsample_ratio) + offset_mov,
    m_els.resize(1, 1, downsample_ratio)
],
               axis=2).play(fr=60, gain=15, magnification=2,
                            offset=0)  # press q to exit

#%% MEMORY MAPPING
# memory map the file in order 'C'
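The excerpt stops just before the memory-mapping step. In the complete examples further down this page, that step looks roughly like the following sketch (names taken from the snippet above):

# Memory map the pw-rigid result in order 'C' and reload it as a (T, x, y) array.
fname_new = cm.save_memmap(mc.fname_tot_els, base_name='memmap_', order='C',
                           border_to_0=bord_px_els)  # exclude motion borders
Yr, dims, T = cm.load_memmap(fname_new)
images = np.reshape(Yr.T, [T] + list(dims), order='F')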
Example #6
                   dview=dview, max_shifts=max_shifts, niter_rig=niter_rig,
                   splits_rig=splits_rig,
                   strides=strides, overlaps=overlaps, splits_els=splits_els,
                   upsample_factor_grid=upsample_factor_grid,
                   max_deviation_rigid=max_deviation_rigid,
                   shifts_opencv=True, nonneg_movie=True)
# note that the file is not loaded in memory
#%%
mc.motion_correct_rigid(save_movie=True)
m_els = cm.load(mc.fname_tot_rig[0])
bord_px_els = np.ceil(np.max(np.abs(mc.shifts_rig))).astype(int)
fname_cor = mc.fname_tot_rig
#%% Run piecewise-rigid motion correction using NoRMCorre
non_rigid = False
if non_rigid:
    mc.motion_correct_pwrigid(save_movie=True)
    m_els = cm.load(mc.fname_tot_els)
    bord_px_els = np.ceil(np.maximum(np.max(np.abs(mc.x_shifts_els)),
                                     np.max(np.abs(mc.y_shifts_els)))).astype(int)
    fname_cor = mc.fname_tot_els
    # maximum shift to be used for trimming against NaNs
#%% compare with original movie
cm.concatenate([m_orig.resize(1, 1, downsample_ratio)+offset_mov,
                m_els.resize(1, 1, downsample_ratio)], 
               axis=2).play(fr=60, gain=5, magnification=1, offset=0)  # press q to exit

#%% MEMORY MAPPING
# memory map the file in order 'C'
fnames = fname_cor   # name of the pw-rigidly corrected file.
border_to_0 = bord_px_els     # number of pixels to exclude
fname_new = cm.save_memmap(fnames, base_name='memmap_', order='C',
Example #7
def run_motion_correction(cropping_file, dview):
    """
    This is the function for motion correction. Its goal is to take in a decoded and
    cropped .tif file, perform motion correction, and save the result as a .mmap file.

    This function is only runnable on the cn76 server because it requires parallel processing.

    Args:
        cropping_file: tif file after cropping
        dview: cluster

    Returns:
        row: pd.DataFrame object
            The row corresponding to the motion corrected analysis state.
    """
    # Get output file paths

    data_dir = os.environ['DATA_DIR_LOCAL'] + 'data/interim/motion_correction/'
    sql = "SELECT mouse,session,trial,is_rest,decoding_v,cropping_v,motion_correction_v,input,home_path,decoding_main FROM Analysis WHERE cropping_main=? ORDER BY motion_correction_v"
    val = [
        cropping_file,
    ]
    cursor.execute(sql, val)
    result = cursor.fetchall()
    data = []
    inter = []
    for x in result:
        inter = x
    for y in inter:
        data.append(y)

    # Update the database

    if data[6] == 0:
        data[6] = 1
        file_name = f"mouse_{data[0]}_session_{data[1]}_trial_{data[2]}.{data[3]}.v{data[4]}.{data[5]}.{data[6]}"
        output_meta_pkl_file_path = f'meta/metrics/{file_name}.pkl'
        sql1 = "UPDATE Analysis SET motion_correction_meta=?,motion_correction_v=? WHERE cropping_main=? "
        val1 = [output_meta_pkl_file_path, data[6], cropping_file]
        cursor.execute(sql1, val1)

    else:
        data[6] += 1
        file_name = f"mouse_{data[0]}_session_{data[1]}_trial_{data[2]}.{data[3]}.v{data[4]}.{data[5]}.{data[6]}"
        output_meta_pkl_file_path = f'meta/metrics/{file_name}.pkl'
        sql2 = "INSERT INTO Analysis (motion_correction_meta,motion_correction_v) VALUES (?,?)"
        val2 = [output_meta_pkl_file_path, data[6]]
        cursor.execute(sql2, val2)
        database.commit()
        sql3 = "UPDATE Analysis SET decoding_main=?,decoding_v=?,mouse=?,session=?,trial=?,is_rest=?,input=?,home_path=?,cropping_v=?,cropping_main=? WHERE motion_correction_meta=? AND motion_correction_v=?"
        val3 = [
            data[9], data[4], data[0], data[1], data[2], data[3], data[7],
            data[8], data[5], cropping_file, output_meta_pkl_file_path, data[6]
        ]
        cursor.execute(sql3, val3)
    database.commit()
    output_meta_pkl_file_path_full = data_dir + output_meta_pkl_file_path

    # Calculate movie minimum to subtract from movie
    cropping_file_full = os.environ['DATA_DIR_LOCAL'] + cropping_file
    min_mov = np.min(cm.load(cropping_file_full))

    # Apply the parameters to the CaImAn algorithm

    sql5 = "SELECT motion_correct,pw_rigid,save_movie_rig,gSig_filt,max_shifts,niter_rig,strides,overlaps,upsample_factor_grid,num_frames_split,max_deviation_rigid,shifts_opencv,use_conda,nonneg_movie, border_nan  FROM Analysis WHERE cropping_main=? "
    val5 = [
        cropping_file,
    ]
    cursor.execute(sql5, val5)
    myresult = cursor.fetchall()
    para = []
    aux = []
    for x in myresult:
        aux = x
    for y in aux:
        para.append(y)
    parameters = {
        'motion_correct': para[0],
        'pw_rigid': para[1],
        'save_movie_rig': para[2],
        'gSig_filt': (para[3], para[3]),
        'max_shifts': (para[4], para[4]),
        'niter_rig': para[5],
        'strides': (para[6], para[6]),
        'overlaps': (para[7], para[7]),
        'upsample_factor_grid': para[8],
        'num_frames_split': para[9],
        'max_deviation_rigid': para[10],
        'shifts_opencv': para[11],
        'use_cuda': para[12],
        'nonneg_movie': para[13],
        'border_nan': para[14]
    }
    caiman_parameters = parameters.copy()
    caiman_parameters['min_mov'] = min_mov
    opts = params.CNMFParams(params_dict=caiman_parameters)

    # Rigid motion correction (in both cases)

    logging.info('Performing rigid motion correction')
    t0 = datetime.datetime.today()

    # Create a MotionCorrect object

    mc = MotionCorrect([cropping_file_full],
                       dview=dview,
                       **opts.get_group('motion'))

    # Perform rigid motion correction

    mc.motion_correct_rigid(save_movie=parameters['save_movie_rig'],
                            template=None)
    dt = int(
        (datetime.datetime.today() - t0).seconds / 60)  # timedelta in minutes
    logging.info(f' Rigid motion correction finished. dt = {dt} min')

    # Obtain template, rigid shifts and border pixels

    total_template_rig = mc.total_template_rig
    shifts_rig = mc.shifts_rig

    # Save template, rigid shifts and border pixels in a dictionary

    meta_pkl_dict = {
        'rigid': {
            'template': total_template_rig,
            'shifts': shifts_rig,
        }
    }
    sql = "UPDATE Analysis SET duration_rigid=? WHERE motion_correction_meta=? AND motion_correction_v=? "
    val = [dt, output_meta_pkl_file_path, data[6]]
    cursor.execute(sql, val)

    if parameters['save_movie_rig'] == 1:
        # Load the movie saved by CaImAn, which is in the wrong
        # directory and is not yet cropped

        logging.info(f' Loading rigid movie for cropping')
        m_rig = cm.load(mc.fname_tot_rig[0])
        logging.info(f' Loaded rigid movie for cropping')

        # Get the cropping points determined by the maximal rigid shifts

        x_, _x, y_, _y = get_crop_from_rigid_shifts(shifts_rig)

        # Crop the movie

        logging.info(
            f' Cropping and saving rigid movie with cropping points: [x_, _x, y_, _y] = {[x_, _x, y_, _y]}'
        )
        m_rig = m_rig.crop(x_, _x, y_, _y, 0, 0)
        meta_pkl_dict['rigid']['cropping_points'] = [x_, _x, y_, _y]
        sql = "UPDATE Analysis SET motion_correction_cropping_points_x1=?,motion_correction_cropping_points_x2=?,motion_correction_cropping_points_y1=?,motion_correction_cropping_points_y2=? WHERE motion_correction_meta=? AND motion_correction_v=? "
        val = [x_, _x, y_, _y, output_meta_pkl_file_path, data[6]]
        cursor.execute(sql, val)

        # Save the movie

        rig_role = 'alternate' if parameters['pw_rigid'] else 'main'
        fname_tot_rig = m_rig.save(data_dir + rig_role + '/' + file_name +
                                   '_rig' + '.mmap',
                                   order='C')
        logging.info(f' Cropped and saved rigid movie as {fname_tot_rig}')

        # Remove the remaining non-cropped movie

        os.remove(mc.fname_tot_rig[0])

        sql = "UPDATE Analysis SET motion_correction_rig_role=? WHERE motion_correction_meta=? AND motion_correction_v=? "
        val = [fname_tot_rig, output_meta_pkl_file_path, data[6]]
        cursor.execute(sql, val)
        database.commit()

    # If specified in the parameters, apply piecewise-rigid motion correction
    if parameters['pw_rigid'] == 1:
        logging.info(f' Performing piecewise-rigid motion correction')
        t0 = datetime.datetime.today()
        # Perform non-rigid (piecewise rigid) motion correction. Use the rigid result as a template.
        mc.motion_correct_pwrigid(save_movie=True, template=total_template_rig)
        # Obtain template and filename
        total_template_els = mc.total_template_els
        fname_tot_els = mc.fname_tot_els[0]

        dt = int((datetime.datetime.today() - t0).seconds /
                 60)  # timedelta in minutes
        meta_pkl_dict['pw_rigid'] = {
            'template': total_template_els,
            'x_shifts': mc.x_shifts_els,
            'y_shifts': mc.y_shifts_els  # initially omitted because they take up space
        }

        logging.info(
            f' Piecewise-rigid motion correction finished. dt = {dt} min')

        # Load the movie saved by CaImAn, which is in the wrong
        # directory and is not yet cropped

        logging.info(f' Loading pw-rigid movie for cropping')
        m_els = cm.load(fname_tot_els)
        logging.info(f' Loaded pw-rigid movie for cropping')

        # Get the cropping points determined by the maximal rigid shifts

        x_, _x, y_, _y = get_crop_from_pw_rigid_shifts(
            np.array(mc.x_shifts_els), np.array(mc.y_shifts_els))
        # Crop the movie

        logging.info(
            f' Cropping and saving pw-rigid movie with cropping points: [x_, _x, y_, _y] = {[x_, _x, y_, _y]}'
        )
        m_els = m_els.crop(x_, _x, y_, _y, 0, 0)
        meta_pkl_dict['pw_rigid']['cropping_points'] = [x_, _x, y_, _y]

        # Save the movie

        fname_tot_els = m_els.save(data_dir + 'main/' + file_name + '_els' +
                                   '.mmap',
                                   order='C')
        logging.info(f'Cropped and saved pw-rigid movie as {fname_tot_els}')

        # Remove the remaining non-cropped movie

        os.remove(mc.fname_tot_els[0])

        sql = "UPDATE Analysis SET  motion_correction_main=?, motion_correction_cropping_points_x1=?,motion_correction_cropping_points_x2=?,motion_correction_cropping_points_y1=?,motion_correction_cropping_points_y2=?,duration_pw_rigid=? WHERE motion_correction_meta=? AND motion_correction_v=? "
        val = [
            fname_tot_els, x_, _x, y_, _y, dt, output_meta_pkl_file_path,
            data[6]
        ]
        cursor.execute(sql, val)
        database.commit()

    # Write meta results dictionary to the pkl file

    pkl_file = open(output_meta_pkl_file_path_full, 'wb')
    pickle.dump(meta_pkl_dict, pkl_file)
    pkl_file.close()

    return fname_tot_els, data[6]
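Example #7 calls get_crop_from_rigid_shifts and get_crop_from_pw_rigid_shifts, which are not shown. A hypothetical sketch of what the rigid variant might look like, assuming it derives crop margins from the extreme per-frame shifts in each direction:

import numpy as np

def get_crop_from_rigid_shifts(shifts_rig):
    # Hypothetical helper: crop margins (x_, _x, y_, _y) from the extreme rigid shifts.
    shifts = np.array(shifts_rig)
    x_ = int(np.ceil(max(0.0, np.max(shifts[:, 0]))))
    _x = int(np.ceil(max(0.0, -np.min(shifts[:, 0]))))
    y_ = int(np.ceil(max(0.0, np.max(shifts[:, 1]))))
    _y = int(np.ceil(max(0.0, -np.min(shifts[:, 1]))))
    return x_, _x, y_, _y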
Example #8
def run_single(batch_dir, UUID, output):
    file_path = os.path.join(batch_dir, UUID)

    n_processes = os.environ['_MESMERIZE_N_THREADS']
    n_processes = int(n_processes)

    c, dview, n_processes = cm.cluster.setup_cluster(backend='local',
                                                     n_processes=n_processes,
                                                     single_thread=False,
                                                     ignore_preexisting=True)

    fname = [file_path + '_input.tiff']
    # TODO: Should just unpack the input params as kwargs
    input_params: dict = pickle.load(open(file_path + '.params', 'rb'))
    mc_kwargs = input_params['mc_kwargs']

    splits_rig = n_processes

    splits_els = n_processes

    if os.environ['_MESMERIZE_USE_CUDA'] == 'True':
        USE_CUDA = True
    else:
        USE_CUDA = False

    min_mov = cm.load(fname[0], subindices=range(200)).min()

    mc = MotionCorrect(fname[0],
                       min_mov,
                       dview=dview,
                       splits_rig=splits_rig,
                       splits_els=splits_els,
                       shifts_opencv=True,
                       nonneg_movie=True,
                       use_cuda=USE_CUDA,
                       **mc_kwargs)

    if 'template' in input_params.keys():
        template = input_params['template']
    else:
        template = None

    mc.motion_correct_pwrigid(save_movie=True, template=template)
    m_els = cm.load(mc.fname_tot_els)
    bord_px_els = np.ceil(
        np.maximum(np.max(np.abs(mc.x_shifts_els)),
                   np.max(np.abs(mc.y_shifts_els)))).astype(int)

    m_els -= np.nanmin(m_els)

    if input_params['output_bit_depth'] == 'Do not convert':
        pass
    elif input_params['output_bit_depth'] == '8':
        m_els = m_els.astype(np.uint8, copy=False)
    elif input_params['output_bit_depth'] == '16':
        m_els = m_els.astype(np.uint16)

    img_out_path = os.path.join(batch_dir, f'{UUID}_mc.tiff')
    tifffile.imsave(img_out_path,
                    m_els,
                    bigtiff=True,
                    imagej=False,
                    compress=1)
    output['output_files'] = [UUID + '_mc.tiff']

    output.update({'status': 1, 'bord_px': int(bord_px_els)})

    if not input_params.get('keep_memmap', False):
        for mf in glob(os.path.join(batch_dir, UUID + '*.mmap')):
            try:
                os.remove(mf)
            except:
                pass

    dview.terminate()

    return output
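The mc_kwargs dictionary is loaded from the pickled .params file and passed straight through to MotionCorrect; its contents are not shown. A plausible set of keys, mirroring the explicit parameters used elsewhere on this page (all values are assumptions):

# Assumed contents of input_params['mc_kwargs']; keys mirror the other examples.
mc_kwargs = {
    'max_shifts': (6, 6),
    'niter_rig': 1,
    'strides': (48, 48),
    'overlaps': (24, 24),
    'upsample_factor_grid': 4,
    'max_deviation_rigid': 3,
    'gSig_filt': None,  # set to e.g. (3, 3) for 1p data
}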
Example #9
def run_multi(batch_dir, UUID, output):
    file_path = os.path.join(batch_dir, UUID)

    n_processes = os.environ['_MESMERIZE_N_THREADS']
    n_processes = int(n_processes)

    filename = [file_path + '_input.tiff']

    seq = tifffile.TiffFile(filename[0]).asarray()
    seq_shape = seq.shape

    # assume default tzxy
    for z in range(seq.shape[1]):
        tifffile.imsave(f'{file_path}_z{z}.tiff', seq[:, z, :, :])

    del seq

    print("******** Creating process pool ********")
    c, dview, n_processes = cm.cluster.setup_cluster(backend='local',
                                                     n_processes=n_processes,
                                                     single_thread=False,
                                                     ignore_preexisting=True)

    # TODO: Should just unpack the input params as kwargs
    input_params = pickle.load(open(file_path + '.params', 'rb'))
    mc_kwargs = input_params['mc_kwargs']

    splits_rig = n_processes
    splits_els = n_processes

    if os.environ['_MESMERIZE_USE_CUDA'] == 'True':
        USE_CUDA = True
    else:
        USE_CUDA = False

    output_files = []
    for z in range(seq_shape[1]):
        print(f"Plane {z} / {seq_shape[1]}")
        filename = [f'{file_path}_z{z}.tiff']
        print('Creating memmap')

        min_mov = cm.load(filename[0], subindices=range(200)).min()

        mc = MotionCorrect(filename[0],
                           min_mov,
                           dview=dview,
                           splits_rig=splits_rig,
                           splits_els=splits_els,
                           shifts_opencv=True,
                           nonneg_movie=True,
                           use_cuda=USE_CUDA,
                           **mc_kwargs)

        mc.motion_correct_pwrigid(save_movie=True)
        m_els = cm.load(mc.fname_tot_els)
        bord_px_els = np.ceil(
            np.maximum(np.max(np.abs(mc.x_shifts_els)),
                       np.max(np.abs(mc.y_shifts_els)))).astype(int)

        m_els -= np.nanmin(m_els)

        if input_params['output_bit_depth'] == 'Do not convert':
            pass
        elif input_params['output_bit_depth'] == '8':
            m_els = m_els.astype(np.uint8, copy=False)
        elif input_params['output_bit_depth'] == '16':
            m_els = m_els.astype(np.uint16)

        if z == 0:
            mc_out = np.zeros(seq_shape, dtype=m_els.dtype)

        mc_out[:, z, :, :] = m_els

    img_out_path = os.path.join(batch_dir, f'{UUID}_mc.tiff')
    tifffile.imsave(img_out_path,
                    mc_out,
                    bigtiff=True,
                    imagej=False,
                    compress=1)
    output['output_files'] = [f'{UUID}_mc.tiff']

    output.update({
        'status': 1,
        'bord_px': int(bord_px_els),
    })

    for mf in glob(os.path.join(batch_dir, UUID + '*.mmap')):
        try:
            os.remove(mf)
        except:
            pass

    dview.terminate()

    return output
Example #10
def main():
    pass # For compatibility between running under Spyder and the CLI

#%% First setup some parameters

    # num processes
    n_proc = 12
    # dataset dependent parameters
    fr = 30                             # imaging rate in frames per second
    decay_time = 0.4                    # length of a typical transient in seconds

    # motion correction parameters
    niter_rig = 1               # number of iterations for rigid motion correction
    max_shifts = (6, 6)         # maximum allowed rigid shift
    # for parallelization, split the movie into num_splits chunks across time
    splits_rig = 56
    # start a new patch for pw-rigid motion correction every x pixels
    strides = (48, 48)
    # overlap between patches (patch size = strides + overlaps)
    overlaps = (24, 24)
    # for parallelization, split the movie into num_splits chunks across time
    splits_els = 56
    upsample_factor_grid = 4    # upsample factor to avoid smearing when merging patches
    # maximum deviation allowed for patch with respect to rigid shifts
    max_deviation_rigid = 3

    # parameters for source extraction and deconvolution
    p = 1                       # order of the autoregressive system
    gnb = 2                     # number of global background components
    merge_thresh = 0.8          # merging threshold, max correlation allowed
    # half-size of the patches in pixels. e.g., if rf=25, patches are 50x50
    rf = 15
    stride_cnmf = 6             # amount of overlap between the patches in pixels
    K = 4                       # number of components per patch
    gSig = [4, 4]               # expected half size of neurons
    # initialization method (if analyzing dendritic data using 'sparse_nmf')
    init_method = 'greedy_roi'
    is_dendrites = False        # flag for analyzing dendritic data
    # sparsity penalty for dendritic data analysis through sparse NMF
    alpha_snmf = None

    # parameters for component evaluation
    min_SNR = 2.5               # signal to noise ratio for accepting a component
    rval_thr = 0.8              # space correlation threshold for accepting a component
    cnn_thr = 0.8               # threshold for CNN based classifier


    
#%% start a cluster for parallel processing
    c, dview, n_processes = cm.cluster.setup_cluster(
        backend='local', n_processes=n_proc, single_thread=False)
    
    print('Parallel processing initialized.')
    print('Beginning motion correction')

#%%% MOTION CORRECTION

    t_ms = time.time()
    # first we create a motion correction object with the parameters specified
    min_mov = cm.load(fname[0], subindices=range(200)).min()
    # this will be subtracted from the movie to make it non-negative
    if not nomc:
      mc = MotionCorrect(fname[0], min_mov,
                         dview=dview, max_shifts=max_shifts, niter_rig=niter_rig,
                         splits_rig=splits_rig, strides=strides, overlaps=overlaps,
                         splits_els=splits_els, upsample_factor_grid=upsample_factor_grid,
                         max_deviation_rigid=max_deviation_rigid,
                         shifts_opencv=True, nonneg_movie=True, use_cuda=use_cuda)
      # note that the file is not loaded in memory

      #%% Run piecewise-rigid motion correction using NoRMCorre
      mc.motion_correct_pwrigid(save_movie=True)
      m_els = cm.load(mc.fname_tot_els)
      bord_px_els = np.ceil(np.maximum(np.max(np.abs(mc.x_shifts_els)),
                                       np.max(np.abs(mc.y_shifts_els)))).astype(int)
      t_mf = time.time()
      print('Motion correction complete in ', int(t_mf - t_ms), ' seconds')
      # maximum shift to be used for trimming against NaNs
      #%% compare with original movie

#%% MEMORY MAPPING
      # memory map the file in order 'C'
      fnames = mc.fname_tot_els   # name of the pw-rigidly corrected file.
      border_to_0 = bord_px_els     # number of pixels to exclude
      fname_new = cm.save_memmap(fnames, base_name='memmap_', order='C',
                                 border_to_0=bord_px_els)  # exclude borders


    bord_px_els = 5
    # now load the file
    Yr, dims, T = cm.load_memmap(fname_new)
    d1, d2 = dims
    images = np.reshape(Yr.T, [T] + list(dims), order='F')
    # load frames in python format (T x X x Y)

#%% restart cluster to clean up memory
    cm.stop_server(dview=dview)
    c, dview, n_processes = cm.cluster.setup_cluster(
        backend='local', n_processes=n_proc, single_thread=False)

#%% RUN CNMF ON PATCHES

    # First extract spatial and temporal components on patches and combine them
    # for this step deconvolution is turned off (p=0)
    print('Beginning initial CNMF fit')
    
    t1 = time.time()

    cnm = cnmf.CNMF(n_processes=n_proc, k=K, gSig=gSig, merge_thresh=merge_thresh,
                    p=0, dview=dview, rf=rf, stride=stride_cnmf, memory_fact=1,
                    method_init=init_method, alpha_snmf=alpha_snmf,
                    only_init_patch=False, gnb=gnb, border_pix=bord_px_els)
    cnm = cnm.fit(images)
    
    t2 = time.time()
    print('Initial CNMF fit complete in ', int(t2-t1), 'seconds.')
    

    Cn = cm.local_correlations(images.transpose(1, 2, 0))
    Cn[np.isnan(Cn)] = 0

#%% COMPONENT EVALUATION
    # the components are evaluated in three ways:
    #   a) the shape of each component must be correlated with the data
    #   b) a minimum peak SNR is required over the length of a transient
    #   c) each shape passes a CNN based classifier
    t3 = time.time()
    print('Estimating component quality')
    
    idx_components, idx_components_bad, SNR_comp, r_values, cnn_preds = \
        estimate_components_quality_auto(images, cnm.A, cnm.C, cnm.b, cnm.f,
                                         cnm.YrA, fr, decay_time, gSig, dims,
                                         dview=dview, min_SNR=min_SNR,
                                         r_values_min=rval_thr, use_cnn=False,
                                         thresh_cnn_min=cnn_thr)
    t4 = time.time()
    print('Component quality estimation complete in ', int(t4-t3),' seconds')


#%% RE-RUN seeded CNMF on accepted patches to refine and perform deconvolution
    print('Re-running CNMF to refine and deconvolve')
    t5 = time.time()
    
    A_in, C_in, b_in, f_in = cnm.A[:,
                                   idx_components], cnm.C[idx_components], cnm.b, cnm.f
    cnm2 = cnmf.CNMF(n_processes=n_proc, k=A_in.shape[-1], gSig=gSig, p=p, dview=dview,
                     merge_thresh=merge_thresh, Ain=A_in, Cin=C_in, b_in=b_in,
                     f_in=f_in, rf=None, stride=None, gnb=gnb,
                     method_deconvolution='oasis', check_nan=True)

    cnm2 = cnm2.fit(images)
    t6 = time.time()
    print('CNMF re-run complete in ', int(t6-t5), ' seconds')
#%% Extract DF/F values

#    F_dff = detrend_df_f(cnm2.A, cnm2.b, cnm2.C, cnm2.f, YrA=cnm2.YrA,
#                         quantileMin=8, frames_window=250)


    save_results = True
    if save_results:
        if os.path.isdir(infile[0]):
            outfile = os.path.join(infile[0], 'caiman_output.npz')
        else:
            outfile = os.path.join(os.path.split(infile[0])[0], 'caiman_output.npz')
        np.savez_compressed(outfile, Cn=Cn, A=cnm2.A.todense(), C=cnm2.C, b=cnm2.b,
                            f=cnm2.f, YrA=cnm2.YrA, d1=d1, d2=d2,
                            idx_components=idx_components,
                            idx_components_bad=idx_components_bad)

    
#%% STOP CLUSTER and clean up log files
    cm.stop_server(dview=dview)
Example #11
def run(batch_dir: str, UUID: str):
    start_time = time()

    output = {'status': 0, 'output_info': ''}
    file_path = batch_dir + '/' + UUID
    n_processes = os.environ['_MESMERIZE_N_THREADS']
    n_processes = int(n_processes)

    c, dview, n_processes = cm.cluster.setup_cluster(
        backend='local',  # use this one
        n_processes=n_processes,
        # number of process to use, if you go out of memory try to reduce this one
        single_thread=False)

    try:
        fname = [file_path + '.tiff']
        input_params = pickle.load(open(file_path + '.params', 'rb'))
        # TODO: Should just unpack the input params as kwargs
        niter_rig = input_params['iters_rigid']
        max_shifts = (input_params['max_shifts_x'],
                      input_params['max_shifts_y'])
        splits_rig = n_processes

        strides = (input_params['strides'], input_params['strides'])
        overlaps = (input_params['overlaps'], input_params['overlaps'])
        splits_els = n_processes
        upsample_factor_grid = input_params['upsample']
        max_deviation_rigid = input_params['max_dev']

        if 'gSig_filt' in input_params.keys():
            gSig_filt = input_params['gSig_filt']
        else:
            gSig_filt = None

        if os.environ['_MESMERIZE_USE_CUDA'] == 'True':
            USE_CUDA = True
        else:
            USE_CUDA = False

        min_mov = cm.load(fname[0], subindices=range(200)).min()

        mc = MotionCorrect(fname[0],
                           min_mov,
                           dview=dview,
                           max_shifts=max_shifts,
                           niter_rig=niter_rig,
                           splits_rig=splits_rig,
                           strides=strides,
                           overlaps=overlaps,
                           splits_els=splits_els,
                           upsample_factor_grid=upsample_factor_grid,
                           max_deviation_rigid=max_deviation_rigid,
                           shifts_opencv=True,
                           nonneg_movie=True,
                           use_cuda=USE_CUDA,
                           gSig_filt=gSig_filt)

        mc.motion_correct_pwrigid(save_movie=True)
        m_els = cm.load(mc.fname_tot_els)
        bord_px_els = np.ceil(
            np.maximum(np.max(np.abs(mc.x_shifts_els)),
                       np.max(np.abs(mc.y_shifts_els)))).astype(int)

        # p = pickle.load(open(UUID + '_workEnv.pik', 'rb'))
        # if p['imdata']['meta']['origin'] == 'mes':
        #     if p['imdata']['meta']['orig_meta']['DataType'] == 'uint16':
        #         pass
        #         # lut = BitDepthConverter.create_lut([np.nanmin(m_els), np.nanmax(m_els)], source=16, out=8)

        #
        # else:
        m_els -= np.nanmin(m_els)

        if input_params['output_bit_depth'] == 'Do not convert':
            pass
        elif input_params['output_bit_depth'] == '8':
            m_els = m_els.astype(np.uint8, copy=False)
        elif input_params['output_bit_depth'] == '16':
            m_els = m_els.astype(np.uint16)

        tifffile.imsave(batch_dir + '/' + UUID + '_mc.tiff',
                        m_els,
                        bigtiff=True,
                        imagej=True,
                        compress=1)

        output.update({'status': 1, 'bord_px': int(bord_px_els)})

    except Exception:
        output.update({'status': 0, 'output_info': traceback.format_exc()})

    for mf in glob(batch_dir + '/' + UUID + '*.mmap'):
        os.remove(mf)

    dview.terminate()

    end_time = time()
    processing_time = (end_time - start_time) / 60

    output_files_list = [UUID + '_mc.tiff', UUID + '.out']

    output.update({
        'processing_time': processing_time,
        'output_files': output_files_list
    })

    json.dump(output, open(file_path + '.out', 'w'))
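The .out file written at the end is plain JSON, so a caller can poll it to check whether the batch item succeeded. A minimal consumer sketch, assuming the same file layout and output keys as the run() above:

# Sketch: read back the .out JSON written by run() above.
import json
import os

def read_output(batch_dir: str, UUID: str) -> dict:
    with open(os.path.join(batch_dir, UUID + '.out')) as f:
        out = json.load(f)
    if out['status'] != 1:
        # output_info holds the traceback when motion correction failed
        raise RuntimeError(out['output_info'])
    return out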
Example #12
def caiman_main_light_weight(fr, fnames, z=0, dend=False):
    """
    Main function to run the CaImAn algorithm. For more details see the GitHub repo and papers.
    fpath(str): folder where to store the plots
    fr(int): framerate
    fnames(list-str): list with the names of the files to be computed together
    z(array): vector with the values of z relative to y
    dend(bool): Boolean to change parameters to look for neurons or dendrites
    display_images(bool): to display and save different plots
    returns
    F_dff(array): array with the dff of the components
    com(array): matrix with the position values of the components as given by caiman
    cnm(struct): struct with different estimates and returns from caiman"""

    # parameters
    decay_time = 0.4  # length of a typical transient in seconds

    # Look for the best parameters for this 2p system and never change them again :)
    # motion correction parameters
    niter_rig = 1  # number of iterations for rigid motion correction
    max_shifts = (3, 3)  # maximum allowed rigid shift
    splits_rig = 10  # for parallelization, split the movie into num_splits chunks across time
    strides = (96, 96)  # start a new patch for pw-rigid motion correction every x pixels
    overlaps = (48, 48)  # overlap between patches (patch size = strides + overlaps)
    splits_els = 10  # for parallelization, split the movie into num_splits chunks across time
    upsample_factor_grid = 4  # upsample factor to avoid smearing when merging patches
    max_deviation_rigid = 3  # maximum deviation allowed for patch with respect to rigid shifts

    # parameters for source extraction and deconvolution
    p = 1  # order of the autoregressive system
    gnb = 2  # number of global background components
    merge_thresh = 0.8  # merging threshold, max correlation allowed
    rf = 25  # half-size of the patches in pixels. e.g., if rf=25, patches are 50x50
    stride_cnmf = 10  # amount of overlap between the patches in pixels
    K = 25  # number of components per patch

    if dend:
        gSig = [1, 1]  # expected half size of neurons
        init_method = 'sparse_nmf'  # initialization method (if analyzing dendritic data using 'sparse_nmf')
        alpha_snmf = 1e-6  # sparsity penalty for dendritic data analysis through sparse NMF
    else:
        gSig = [3, 3]  # expected half size of neurons
        init_method = 'greedy_roi'  # initialization method (if analyzing dendritic data using 'sparse_nmf')
        alpha_snmf = None  # sparsity penalty for dendritic data analysis through sparse NMF

    # parameters for component evaluation
    min_SNR = 2.5  # signal to noise ratio for accepting a component
    rval_thr = 0.8  # space correlation threshold for accepting a component
    cnn_thr = 0.8  # threshold for CNN based classifier

    dview = None  # parallel processing keeps crashing.

    print('***************Starting motion correction*************')
    print('files:')
    print(fnames)

    # %% start a cluster for parallel processing
    # c, dview, n_processes = cm.cluster.setup_cluster(backend='local', n_processes=None, single_thread=False)

    # %%% MOTION CORRECTION
    # first we create a motion correction object with the parameters specified
    min_mov = cm.load(fnames[0]).min()
    # this will be subtracted from the movie to make it non-negative

    mc = MotionCorrect(fnames,
                       min_mov,
                       dview=dview,
                       max_shifts=max_shifts,
                       niter_rig=niter_rig,
                       splits_rig=splits_rig,
                       strides=strides,
                       overlaps=overlaps,
                       splits_els=splits_els,
                       upsample_factor_grid=upsample_factor_grid,
                       max_deviation_rigid=max_deviation_rigid,
                       shifts_opencv=True,
                       nonneg_movie=True)
    # note that the file is not loaded in memory

    # %% Run piecewise-rigid motion correction using NoRMCorre
    mc.motion_correct_pwrigid(save_movie=True)
    bord_px_els = np.ceil(
        np.maximum(np.max(np.abs(mc.x_shifts_els)),
                   np.max(np.abs(mc.y_shifts_els)))).astype(int)

    totdes = [np.nansum(mc.x_shifts_els), np.nansum(mc.y_shifts_els)]
    print('***************Motion correction has ended*************')
    # maximum shift to be used for trimming against NaNs

    # %% MEMORY MAPPING
    # memory map the file in order 'C'
    fnames = mc.fname_tot_els  # name of the pw-rigidly corrected file.
    fname_new = cm.save_memmap(fnames,
                               base_name='memmap_',
                               order='C',
                               border_to_0=bord_px_els)  # exclude borders

    # now load the file
    Yr, dims, T = cm.load_memmap(fname_new)
    d1, d2 = dims
    images = np.reshape(Yr.T, [T] + list(dims), order='F')
    # load frames in python format (T x X x Y)

    # %% restart cluster to clean up memory
    # cm.stop_server(dview=dview)
    # c, dview, n_processes = cm.cluster.setup_cluster(backend='local', n_processes=None, single_thread=False)

    # %% RUN CNMF ON PATCHES
    print('***************Running CNMF...*************')

    # First extract spatial and temporal components on patches and combine them
    # for this step deconvolution is turned off (p=0)

    cnm = cnmf.CNMF(n_processes=1,
                    k=K,
                    gSig=gSig,
                    merge_thresh=merge_thresh,
                    p=0,
                    dview=dview,
                    rf=rf,
                    stride=stride_cnmf,
                    memory_fact=1,
                    method_init=init_method,
                    alpha_snmf=alpha_snmf,
                    only_init_patch=False,
                    gnb=gnb,
                    border_pix=bord_px_els)
    cnm = cnm.fit(images)

    # %% COMPONENT EVALUATION
    # the components are evaluated in three ways:
    #   a) the shape of each component must be correlated with the data
    #   b) a minimum peak SNR is required over the length of a transient
    #   c) each shape passes a CNN based classifier

    idx_components, idx_components_bad, SNR_comp, r_values, cnn_preds = \
        estimate_components_quality_auto(images, cnm.estimates.A, cnm.estimates.C, cnm.estimates.b,
                                         cnm.estimates.f,
                                         cnm.estimates.YrA, fr, decay_time, gSig, dims,
                                         dview=dview, min_SNR=min_SNR,
                                         r_values_min=rval_thr, use_cnn=False,
                                         thresh_cnn_min=cnn_thr)

    # %% RE-RUN seeded CNMF on accepted patches to refine and perform deconvolution
    A_in, C_in, b_in, f_in = cnm.estimates.A[:, idx_components], cnm.estimates.C[
        idx_components], cnm.estimates.b, cnm.estimates.f
    cnm2 = cnmf.CNMF(n_processes=1,
                     k=A_in.shape[-1],
                     gSig=gSig,
                     p=p,
                     dview=dview,
                     merge_thresh=merge_thresh,
                     Ain=A_in,
                     Cin=C_in,
                     b_in=b_in,
                     f_in=f_in,
                     rf=None,
                     stride=None,
                     gnb=gnb,
                     method_deconvolution='oasis',
                     check_nan=True)

    print('***************Fit*************')
    cnm2 = cnm2.fit(images)

    print('***************Extracting DFFs*************')
    # %% Extract DF/F values

    # cm.stop_server(dview=dview)
    try:
        F_dff = detrend_df_f(cnm2.estimates.A,
                             cnm2.estimates.b,
                             cnm2.estimates.C,
                             cnm2.estimates.f,
                             YrA=cnm2.estimates.YrA,
                             quantileMin=8,
                             frames_window=250)
        # F_dff = detrend_df_f(cnm.A, cnm.b, cnm.C, cnm.f, YrA=cnm.YrA, quantileMin=8, frames_window=250)
    except:
        F_dff = cnm2.estimates.C * np.nan
        print('WHAAT went wrong again?')

    print('***************stopping cluster*************')
    # %% STOP CLUSTER and clean up log files
    # cm.stop_server(dview=dview)

    # ***************************************************************************************
    # Preparing output data
    # F_dff  -> DFF values,  is a matrix [number of neurons, length recording]

    # com  --> center of mass,  is a matrix [number of neurons, 2]
    print('***************preparing output data*************')

    if len(dims) <= 2:
        if len(z) == 1:
            com = np.concatenate(
                (cm.base.rois.com(cnm2.estimates.A, dims[0], dims[1]),
                 np.zeros((cnm2.estimates.A.shape[1], 1)) + z), 1)
        elif len(z) == dims[0]:
            auxcom = cm.base.rois.com(cnm2.estimates.A, dims[0], dims[1])
            zy = np.zeros((auxcom.shape[0], 1))
            for y in np.arange(auxcom.shape[0]):
                zy[y, 0] = z[int(auxcom[y, 0])]
            com = np.concatenate((auxcom, zy), 1)
        else:
            print(
                'WARNING: Z value was not correctly defined, only X and Y values on file, z==zeros'
            )
            print(['length of z was: ' + str(len(z))])
            com = np.concatenate(
                (cm.base.rois.com(cnm2.estimates.A, dims[0], dims[1]),
                 np.zeros((cnm2.estimates.A.shape[1], 1))), 1)
    else:
        com = cm.base.rois.com(cnm2.estimates.A, dims[0], dims[1], dims[2])

    return F_dff, com, cnm2, totdes, SNR_comp[idx_components]
Example #13
def run(batch_dir: str, UUID: str):
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format=
        "%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s] [%(process)d] %(message)s"
    )
    start_time = time()

    output = {'status': 0, 'output_info': ''}
    file_path = os.path.join(batch_dir, UUID)
    n_processes = os.environ['_MESMERIZE_N_THREADS']
    n_processes = int(n_processes)

    c, dview, n_processes = cm.cluster.setup_cluster(backend='local',
                                                     n_processes=n_processes,
                                                     single_thread=False,
                                                     ignore_preexisting=True)

    try:
        fname = [file_path + '_input.tiff']
        input_params = pickle.load(open(file_path + '.params', 'rb'))
        # TODO: Should just unpack the input params as kwargs
        mc_kwargs = input_params['mc_kwargs']

        splits_rig = n_processes

        splits_els = n_processes

        if os.environ['_MESMERIZE_USE_CUDA'] == 'True':
            USE_CUDA = True
        else:
            USE_CUDA = False

        min_mov = cm.load(fname[0], subindices=range(200)).min()

        mc = MotionCorrect(fname[0],
                           min_mov,
                           dview=dview,
                           splits_rig=splits_rig,
                           splits_els=splits_els,
                           shifts_opencv=True,
                           nonneg_movie=True,
                           use_cuda=USE_CUDA,
                           **mc_kwargs)

        mc.motion_correct_pwrigid(save_movie=True)
        m_els = cm.load(mc.fname_tot_els)
        bord_px_els = np.ceil(
            np.maximum(np.max(np.abs(mc.x_shifts_els)),
                       np.max(np.abs(mc.y_shifts_els)))).astype(int)

        m_els -= np.nanmin(m_els)

        if input_params['output_bit_depth'] == 'Do not convert':
            pass
        elif input_params['output_bit_depth'] == '8':
            m_els = m_els.astype(np.uint8, copy=False)
        elif input_params['output_bit_depth'] == '16':
            m_els = m_els.astype(np.uint16)

        img_out_path = os.path.join(batch_dir, f'{UUID}_mc.tiff')
        tifffile.imsave(img_out_path,
                        m_els,
                        bigtiff=True,
                        imagej=True,
                        compress=1)

        output.update({'status': 1, 'bord_px': int(bord_px_els)})

    except Exception:
        output.update({'status': 0, 'output_info': traceback.format_exc()})

    for mf in glob(os.path.join(batch_dir, UUID + '*.mmap')):
        try:
            os.remove(mf)
        except:
            pass

    dview.terminate()

    end_time = time()
    processing_time = (end_time - start_time) / 60

    output_files_list = [UUID + '_mc.tiff']

    output.update({
        'processing_time': processing_time,
        'output_files': output_files_list
    })

    json.dump(output, open(file_path + '.out', 'w'))
Example #14
def main():
    pass # For compatibility between running under Spyder and the CLI

#%% First setup some parameters for data and motion correction

    # dataset dependent parameters
    fname = ['Sue_2x_3000_40_-46.tif']  # filename to be processed
    fr = 30                             # imaging rate in frames per second
    decay_time = 0.4                    # length of a typical transient in seconds
    dxy = (2., 2.)                      # spatial resolution in x and y in (um per pixel)
    max_shift_um = (12., 12.)           # maximum shift in um
    patch_motion_um = (100., 100.)      # patch size for non-rigid motion correction in um
    
    # motion correction parameters
    pwrigid_motion_correct = True       # flag to select rigid vs pw_rigid motion correction
    # maximum allowed rigid shift in pixels
    max_shifts = tuple([int(a / b) for a, b in zip(max_shift_um, dxy)])
    # for parallelization, split the movie into num_splits chunks across time
    splits_rig = 56
    # start a new patch for pw-rigid motion correction every x pixels
    strides = tuple([int(a / b) for a, b in zip(patch_motion_um, dxy)])
    # overlap between patches (patch size = strides + overlaps)
    overlaps = (24, 24)
    # for parallelization, split the movie into num_splits chunks across time
    splits_els = 56
    upsample_factor_grid = 4    # upsample factor to avoid smearing when merging patches
    # maximum deviation allowed for patch with respect to rigid shifts
    max_deviation_rigid = 3

#%% download the dataset if it's not present in your folder
    if fname[0] in ['Sue_2x_3000_40_-46.tif', 'demoMovie.tif']:
        fname = [download_demo(fname[0])]

#%% play the movie
    # playing the movie using opencv. It requires loading the movie in memory.
    # To close the video press q
    display_images = False

    if display_images:
        m_orig = cm.load_movie_chain(fname)
        downsample_ratio = 0.2
        moviehandle = m_orig.resize(1, 1, downsample_ratio)
        moviehandle.play(q_max=99.5, fr=60, magnification=2)

#%% start a cluster for parallel processing
    c, dview, n_processes = cm.cluster.setup_cluster(
        backend='local', n_processes=None, single_thread=False)


#%%% MOTION CORRECTION
    # first we create a motion correction object with the parameters specified
    min_mov = cm.load(fname[0], subindices=range(200)).min()
    # this will be subtracted from the movie to make it non-negative

    mc = MotionCorrect(fname, min_mov, dview=dview, max_shifts=max_shifts, 
                       splits_rig=splits_rig,
                       strides=strides, overlaps=overlaps, 
                       splits_els=splits_els, border_nan='copy',
                       upsample_factor_grid=upsample_factor_grid,
                       max_deviation_rigid=max_deviation_rigid,
                       shifts_opencv=True, nonneg_movie=True)
    # note that the file is not loaded in memory

#%% Run piecewise-rigid motion correction using NoRMCorre
    if pwrigid_motion_correct:
        mc.motion_correct_pwrigid(save_movie=True)
        m_els = cm.load(mc.fname_tot_els)
        bord_px_els = np.ceil(np.maximum(np.max(np.abs(mc.x_shifts_els)),
                                         np.max(np.abs(mc.y_shifts_els)))).astype(int)
        fnames = mc.fname_tot_els  # name of the pw-rigidly corrected file.

    else:
        mc.motion_correct_rigid(save_movie=True)
        m_els = cm.load(mc.fname_tot_rig)
        bord_px_els = np.ceil(np.max(np.abs(mc.shifts_rig))).astype(int)
        fnames = mc.fname_tot_rig  # name of the rigidly corrected file.

    # maximum shift to be used for trimming against NaNs
#%% compare with original movie
    if display_images:
        downsample_ratio = 0.2
        moviehandle = cm.concatenate([m_orig.resize(1, 1, downsample_ratio) - min_mov,
                                      m_els.resize(1, 1, downsample_ratio)],
                                     axis=2)
        moviehandle.play(fr=60, q_max=99.5, magnification=2)  # press q to exit

#%% MEMORY MAPPING
    # memory map the file in order 'C'
    border_to_0 = bord_px_els  # exclude borders due to motion correction
#    border_to_0 = 0 if mc.border_nan == 'copy' else bord_px_els
        # you can include the boundaries if you used the 'copy' option in the
        # motion correction, although be careful about components near the boundaries
    fname_new = cm.save_memmap(fnames, base_name='memmap_', order='C',
                               border_to_0=border_to_0)  # exclude borders

    # now load the file
    Yr, dims, T = cm.load_memmap(fname_new)
    images = np.reshape(Yr.T, [T] + list(dims), order='F')
    # load frames in python format (T x X x Y)
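    # Shape note (added for clarity, not part of the original demo): Yr returned
    # by load_memmap is (d1*d2, T) in Fortran order, so the reshape above yields
    # images with shape (T, d1, d2), i.e. one full frame per time point.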

#%% restart cluster to clean up memory
    cm.stop_server(dview=dview)
    c, dview, n_processes = cm.cluster.setup_cluster(
        backend='local', n_processes=None, single_thread=False)

#%%  parameters for source extraction and deconvolution
    p = 1                       # order of the autoregressive system
    gnb = 2                     # number of global background components
    merge_thresh = 0.8          # merging threshold, max correlation allowed
    # half-size of the patches in pixels. e.g., if rf=25, patches are 50x50
    rf = 15
    stride_cnmf = 6             # amount of overlap between the patches in pixels
    K = 4                       # number of components per patch
    gSig = [4, 4]               # expected half size of neurons
    # initialization method (if analyzing dendritic data using 'sparse_nmf')
    method_init = 'greedy_roi'

    # parameters for component evaluation

    opts = params.CNMFParams(dims=dims, fr=fr, decay_time=decay_time,
                             method_init=method_init, gSig=gSig,
                             merge_thresh=merge_thresh, p=p, gnb=gnb, k=K,
                             rf=rf, stride=stride_cnmf, rolling_sum=True)

#%% RUN CNMF ON PATCHES

    # First extract spatial and temporal components on patches and combine them
    # for this step deconvolution is turned off (p=0)

    opts.set('temporal', {'p': 0})
    cnm = cnmf.CNMF(n_processes, params=opts, dview=dview)
    cnm = cnm.fit(images)

#%% plot contours of found components
    Cn = cm.local_correlations(images.transpose(1, 2, 0))
    Cn[np.isnan(Cn)] = 0
    cnm.estimates.plot_contours(img=Cn)
    plt.title('Contour plots of found components')

#%% COMPONENT EVALUATION
    # the components are evaluated in three ways:
    #   a) the shape of each component must be correlated with the data
    #   b) a minimum peak SNR is required over the length of a transient
    #   c) each shape passes a CNN based classifier
    min_SNR = 2.5       # signal to noise ratio for accepting a component
    rval_thr = 0.8      # space correlation threshold for accepting a component
    cnn_thr = 0.8       # threshold for CNN based classifier
    cnm.params.set('quality', {'fr': fr,
                               'decay_time': decay_time,
                               'min_SNR': min_SNR,
                               'rval_thr': rval_thr,
                               'use_cnn': True,
                               'min_cnn_thr': cnn_thr})
    cnm.estimates.evaluate_components(images, cnm.params, dview=dview)

#%% PLOT COMPONENTS
    cnm.estimates.plot_contours(img=Cn, idx=cnm.estimates.idx_components)

#%% VIEW TRACES (accepted and rejected)

    if display_images:
        cnm.estimates.view_components(images, img=Cn, idx=cnm.estimates.idx_components)
        cnm.estimates.view_components(images, img=Cn, idx=cnm.estimates.idx_components_bad)

#%% RE-RUN seeded CNMF on accepted patches to refine and perform deconvolution

    cnm.dview = None
    cnm2 = deepcopy(cnm)
    cnm2.dview = dview
    cnm2.params.set('patch', {'rf': None})
    cnm2.params.set('temporal', {'p': p})
    cnm2 = cnm2.fit(images)

#%% Extract DF/F values
    cnm2.estimates.detrend_df_f(quantileMin=8, frames_window=250)
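    # The detrended traces are stored on the estimates object (attribute name as
    # exposed by the CaImAn API, to the best of our knowledge):
    # F_dff = cnm2.estimates.F_dff  # array of shape (n_components, T)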

#%% Show final traces
    cnm2.estimates.view_components(Yr, img=Cn)

#%% reconstruct denoised movie (press q to exit)
    if display_images:
        cnm2.estimates.play_movie(images, q_max=99.9, gain_res=2,
                                  magnification=2,
                                  bpx=border_to_0,
                                  include_bck=True)

#%% STOP CLUSTER and clean up log files
    cm.stop_server(dview=dview)
    log_files = glob.glob('*_LOG_*')
    for log_file in log_files:
        os.remove(log_file)
Example #15
0
def main():
    pass  # For compatibility between running under Spyder and the CLI

    #%% First setup some parameters

    # dataset dependent parameters
    display_images = False  # Set this to true to show movies and plots
    fname = ['Sue_2x_3000_40_-46.tif']  # filename to be processed
    fr = 30  # imaging rate in frames per second
    decay_time = 0.4  # length of a typical transient in seconds

    # motion correction parameters
    niter_rig = 1  # number of iterations for rigid motion correction
    max_shifts = (6, 6)  # maximum allowed rigid shift
    # for parallelization split the movies in num_splits chunks across time
    splits_rig = 56
    # start a new patch for pw-rigid motion correction every x pixels
    strides = (48, 48)
    # overlap between patches (size of patch = strides + overlaps)
    overlaps = (24, 24)
    # for parallelization split the movies in num_splits chunks across time
    splits_els = 56
    upsample_factor_grid = 4  # upsample factor to avoid smearing when merging patches
    # maximum deviation allowed for patch with respect to rigid shifts
    max_deviation_rigid = 3

    # parameters for source extraction and deconvolution
    p = 1  # order of the autoregressive system
    gnb = 2  # number of global background components
    merge_thresh = 0.8  # merging threshold, max correlation allowed
    # half-size of the patches in pixels. e.g., if rf=25, patches are 50x50
    rf = 15
    stride_cnmf = 6  # amount of overlap between the patches in pixels
    K = 4  # number of components per patch
    gSig = [4, 4]  # expected half size of neurons
    # initialization method (if analyzing dendritic data using 'sparse_nmf')
    init_method = 'greedy_roi'
    is_dendrites = False  # flag for analyzing dendritic data
    # sparsity penalty for dendritic data analysis through sparse NMF
    alpha_snmf = None

    # parameters for component evaluation
    min_SNR = 2.5  # signal to noise ratio for accepting a component
    rval_thr = 0.8  # space correlation threshold for accepting a component
    cnn_thr = 0.8  # threshold for CNN based classifier

    #%% download the dataset if it's not present in your folder
    if fname[0] in ['Sue_2x_3000_40_-46.tif', 'demoMovie.tif']:
        fname = [download_demo(fname[0])]

#%% play the movie
# playing the movie using opencv. It requires loading the movie in memory. To
# close the video press q

    m_orig = cm.load_movie_chain(fname[:1])
    downsample_ratio = 0.2
    offset_mov = -np.min(m_orig[:100])
    moviehandle = m_orig.resize(1, 1, downsample_ratio)
    if display_images:
        moviehandle.play(gain=10, offset=offset_mov, fr=30, magnification=2)

#%% start a cluster for parallel processing
    c, dview, n_processes = cm.cluster.setup_cluster(backend='local',
                                                     n_processes=None,
                                                     single_thread=False)

    #%%% MOTION CORRECTION
    # first we create a motion correction object with the parameters specified
    min_mov = cm.load(fname[0], subindices=range(200)).min()
    # this will be subtracted from the movie to make it non-negative

    mc = MotionCorrect(fname[0],
                       min_mov,
                       dview=dview,
                       max_shifts=max_shifts,
                       niter_rig=niter_rig,
                       splits_rig=splits_rig,
                       strides=strides,
                       overlaps=overlaps,
                       splits_els=splits_els,
                       upsample_factor_grid=upsample_factor_grid,
                       max_deviation_rigid=max_deviation_rigid,
                       shifts_opencv=True,
                       nonneg_movie=True)
    # note that the file is not loaded in memory

    #%% Run piecewise-rigid motion correction using NoRMCorre
    mc.motion_correct_pwrigid(save_movie=True)
    m_els = cm.load(mc.fname_tot_els)
    bord_px_els = np.ceil(
        np.maximum(np.max(np.abs(mc.x_shifts_els)),
                   np.max(np.abs(mc.y_shifts_els)))).astype(int)
    # maximum shift to be used for trimming against NaNs
    #%% compare with original movie
    moviehandle = cm.concatenate([
        m_orig.resize(1, 1, downsample_ratio) + offset_mov,
        m_els.resize(1, 1, downsample_ratio)
    ],
                                 axis=2)
    display_images = False
    if display_images:
        moviehandle.play(fr=60, q_max=99.5, magnification=2,
                         offset=0)  # press q to exit

#%% MEMORY MAPPING
# memory map the file in order 'C'
    fnames = mc.fname_tot_els  # name of the pw-rigidly corrected file.
    border_to_0 = bord_px_els  # number of pixels to exclude
    fname_new = cm.save_memmap(fnames,
                               base_name='memmap_',
                               order='C',
                               border_to_0=bord_px_els)  # exclude borders

    # now load the file
    Yr, dims, T = cm.load_memmap(fname_new)
    d1, d2 = dims
    images = np.reshape(Yr.T, [T] + list(dims), order='F')
    # load frames in python format (T x X x Y)

    #%% restart cluster to clean up memory
    cm.stop_server(dview=dview)
    c, dview, n_processes = cm.cluster.setup_cluster(backend='local',
                                                     n_processes=None,
                                                     single_thread=False)

    #%% RUN CNMF ON PATCHES

    # First extract spatial and temporal components on patches and combine them
    # for this step deconvolution is turned off (p=0)
    t1 = time.time()

    cnm = cnmf.CNMF(n_processes=1,
                    k=K,
                    gSig=gSig,
                    merge_thresh=merge_thresh,
                    p=0,
                    dview=dview,
                    rf=rf,
                    stride=stride_cnmf,
                    memory_fact=1,
                    method_init=init_method,
                    alpha_snmf=alpha_snmf,
                    only_init_patch=False,
                    gnb=gnb,
                    border_pix=bord_px_els)
    cnm = cnm.fit(images)

    #%% plot contours of found components
    Cn = cm.local_correlations(images.transpose(1, 2, 0))
    Cn[np.isnan(Cn)] = 0
    plt.figure()
    crd = plot_contours(cnm.A, Cn, thr=0.9)
    plt.title('Contour plots of found components')

    #%% COMPONENT EVALUATION
    # the components are evaluated in three ways:
    #   a) the shape of each component must be correlated with the data
    #   b) a minimum peak SNR is required over the length of a transient
    #   c) each shape passes a CNN based classifier

    idx_components, idx_components_bad, SNR_comp, r_values, cnn_preds = \
        estimate_components_quality_auto(images, cnm.A, cnm.C, cnm.b, cnm.f,
                                         cnm.YrA, fr, decay_time, gSig, dims,
                                         dview=dview, min_SNR=min_SNR,
                                         r_values_min=rval_thr, use_cnn=False,
                                         thresh_cnn_min=cnn_thr)
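    # Optional summary of the evaluation above (a small added sketch, not in the
    # original example):
    # print('Accepted %d components, rejected %d'
    #       % (len(idx_components), len(idx_components_bad)))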

    #%% PLOT COMPONENTS

    if display_images:
        plt.figure()
        plt.subplot(121)
        crd_good = cm.utils.visualization.plot_contours(cnm.A[:,
                                                              idx_components],
                                                        Cn,
                                                        thr=.8,
                                                        vmax=0.75)
        plt.title('Contour plots of accepted components')
        plt.subplot(122)
        crd_bad = cm.utils.visualization.plot_contours(
            cnm.A[:, idx_components_bad], Cn, thr=.8, vmax=0.75)
        plt.title('Contour plots of rejected components')

#%% VIEW TRACES (accepted and rejected)

    if display_images:
        view_patches_bar(Yr,
                         cnm.A.tocsc()[:, idx_components],
                         cnm.C[idx_components],
                         cnm.b,
                         cnm.f,
                         dims[0],
                         dims[1],
                         YrA=cnm.YrA[idx_components],
                         img=Cn)

        view_patches_bar(Yr,
                         cnm.A.tocsc()[:, idx_components_bad],
                         cnm.C[idx_components_bad],
                         cnm.b,
                         cnm.f,
                         dims[0],
                         dims[1],
                         YrA=cnm.YrA[idx_components_bad],
                         img=Cn)

#%% RE-RUN seeded CNMF on accepted patches to refine and perform deconvolution
    A_in, C_in, b_in, f_in = (cnm.A[:, idx_components], cnm.C[idx_components],
                              cnm.b, cnm.f)
    cnm2 = cnmf.CNMF(n_processes=1,
                     k=A_in.shape[-1],
                     gSig=gSig,
                     p=p,
                     dview=dview,
                     merge_thresh=merge_thresh,
                     Ain=A_in,
                     Cin=C_in,
                     b_in=b_in,
                     f_in=f_in,
                     rf=None,
                     stride=None,
                     gnb=gnb,
                     method_deconvolution='oasis',
                     check_nan=True)

    cnm2 = cnm2.fit(images)

    #%% Extract DF/F values

    F_dff = detrend_df_f(cnm2.A,
                         cnm2.b,
                         cnm2.C,
                         cnm2.f,
                         YrA=cnm2.YrA,
                         quantileMin=8,
                         frames_window=250)
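    # A minimal plotting sketch for the DF/F traces computed above (added here,
    # not part of the original example):
    # if display_images:
    #     plt.figure()
    #     plt.plot(F_dff[:5].T)  # traces of the first five components
    #     plt.xlabel('frame')
    #     plt.ylabel('DF/F')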

    #%% Show final traces
    cnm2.view_patches(Yr, dims=dims, img=Cn)

    #%% STOP CLUSTER and clean up log files
    cm.stop_server(dview=dview)
    log_files = glob.glob('*_LOG_*')
    for log_file in log_files:
        os.remove(log_file)

#%% reconstruct denoised movie
    denoised = cm.movie(cnm2.A.dot(cnm2.C) + cnm2.b.dot(cnm2.f)).reshape(
        dims + (-1, ), order='F').transpose([2, 0, 1])
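    # Reconstruction shapes (added note): cnm2.A is (d1*d2, K) and cnm2.C is
    # (K, T), so A.dot(C) + b.dot(f) is (d1*d2, T); the reshape and transpose
    # above turn it back into a movie of shape (T, d1, d2).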

    #%% play along side original data
    moviehandle = cm.concatenate([
        m_els.resize(1, 1, downsample_ratio),
        denoised.resize(1, 1, downsample_ratio)
    ],
                                 axis=2)
    if display_images:
        moviehandle.play(fr=60, gain=15, magnification=2,
                         offset=0)  # press q to exit
# %% plot rigid shifts
pl.close()
pl.plot(mc.shifts_rig)
pl.legend(['x shifts', 'y shifts'])
pl.xlabel('frames')
pl.ylabel('pixels')
# %% inspect movie
downsample_ratio = params_display['downsample_ratio']
# TODO: todocument
offset_mov = -np.min(m_orig[:100])
m_rig.resize(1, 1, downsample_ratio).play(
    gain=10, offset=offset_mov * .25, fr=30, magnification=2, bord_px=bord_px_rig)
# %%
# a compute-intensive but parallelized part
t1 = time.time()
mc.motion_correct_pwrigid(save_movie=True,
                          template=mc.total_template_rig, show_template=True)
# TODO: change var name els= pwr
m_els = cm.load(mc.fname_tot_els)
pl.imshow(mc.total_template_els, cmap='gray')
# TODO: show screenshot 5
# TODO: bug sometimes saying there is no y_shifts_els
bord_px_els = np.ceil(np.maximum(np.max(np.abs(mc.x_shifts_els)),
                                 np.max(np.abs(mc.y_shifts_els)))).astype(int)
# %% visualize elastic shifts
pl.close()
pl.subplot(2, 1, 1)
pl.plot(mc.x_shifts_els)
pl.ylabel('x shifts (pixels)')
pl.subplot(2, 1, 2)
pl.plot(mc.y_shifts_els)
pl.ylabel('y_shifts (pixels)')
def main(index, row, parameters, dview):
    '''
    This is the function for motion correction. Its goal is to take in a decoded
    and cropped .tif file, perform motion correction, and save the result as a
    .mmap file.

    This function is only runnable on the cn76 server because it requires
    parallel processing.

    Args:
        index: tuple
            The index of the analysis state to be motion corrected.
        row: pd.DataFrame object
            The row corresponding to the analysis state to be motion corrected.
        parameters: dict
            Motion correction parameters for this analysis step.
        dview: object
            CaImAn cluster handle used for parallel processing.

    Returns:
        index: tuple
            The index of the motion-corrected analysis state.
        row: pd.DataFrame object
            The row corresponding to the motion-corrected analysis state.
    '''

    # Forcing parameters
    if not parameters['pw_rigid']:
        parameters['save_movie_rig'] = True

    # Get input file
    input_tif_file_path = eval(row.loc['cropping_output'])['main']
    if not os.path.isfile(input_tif_file_path):
        input_tif_file_path = src.pipeline.get_expected_file_path(
            'cropping', index, 'main/', 'tif')
        if not os.path.isfile(input_tif_file_path):
            logging.error(
                'Cropping file not found. Cancelling motion correction.')
            return index, row

    # Get output file paths
    data_dir = 'data/interim/motion_correction/'
    file_name = src.pipeline.create_file_name(step_index, index)
    output_meta_pkl_file_path = data_dir + f'meta/metrics/{file_name}.pkl'

    # Create a dictionary with the output
    output = {
        'meta': {
            'analysis': {
                'analyst': os.environ['ANALYST'],
                'date': datetime.datetime.today().strftime("%m-%d-%Y"),
                'time': datetime.datetime.today().strftime("%H:%M:%S")
            },
            'metrics': {
                'other': output_meta_pkl_file_path
            }
        }
    }

    row.loc['motion_correction_parameters'] = str(parameters)

    # Calculate movie minimum to subtract from movie
    min_mov = np.min(cm.load(input_tif_file_path))
    # Apply the parameters to the CaImAn algorithm
    caiman_parameters = parameters.copy()
    caiman_parameters['min_mov'] = min_mov
    opts = params.CNMFParams(params_dict=caiman_parameters)

    # Rigid motion correction (in both cases)
    logging.info(f'{index} Performing rigid motion correction')
    t0 = datetime.datetime.today()

    # Create a MotionCorrect object
    mc = MotionCorrect([input_tif_file_path],
                       dview=dview,
                       **opts.get_group('motion'))
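    # Note (added): opts.get_group('motion') is assumed to forward the
    # motion-related entries of the parameters dict (e.g. max_shifts, strides,
    # overlaps, pw_rigid) as keyword arguments to MotionCorrect.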
    # Perform rigid motion correction
    mc.motion_correct_rigid(save_movie=parameters['save_movie_rig'],
                            template=None)
    dt = int(
        (datetime.datetime.today() - t0).seconds / 60)  # timedelta in minutes
    logging.info(f'{index} Rigid motion correction finished. dt = {dt} min')
    # Obtain template, rigid shifts and border pixels
    total_template_rig = mc.total_template_rig
    shifts_rig = mc.shifts_rig
    # Save template, rigid shifts and border pixels in a dictionary
    meta_pkl_dict = {
        'rigid': {
            'template': total_template_rig,
            'shifts': shifts_rig,
        }
    }
    output['meta']['duration'] = {'rigid': dt}

    if parameters['save_movie_rig']:
        # Load the movie saved by CaImAn, which is in the wrong
        # directory and is not yet cropped
        logging.info(f'{index} Loading rigid movie for cropping')
        m_rig = cm.load(mc.fname_tot_rig[0])
        logging.info(f'{index} Loaded rigid movie for cropping')
        # Get the cropping points determined by the maximal rigid shifts
        x_, _x, y_, _y = get_crop_from_rigid_shifts(shifts_rig)
        # Crop the movie
        logging.info(
            f'{index} Cropping and saving rigid movie with cropping points: [x_, _x, y_, _y] = {[x_, _x, y_, _y]}'
        )
        m_rig = m_rig.crop(x_, _x, y_, _y, 0, 0)
        meta_pkl_dict['rigid']['cropping_points'] = [x_, _x, y_, _y]
        # Save the movie
        rig_role = 'alternate' if parameters['pw_rigid'] else 'main'
        fname_tot_rig = m_rig.save(data_dir + rig_role + '/' + file_name +
                                   '_rig' + '.mmap',
                                   order='C')
        logging.info(
            f'{index} Cropped and saved rigid movie as {fname_tot_rig}')
        # Store the total path in output
        output[rig_role] = fname_tot_rig
        # Remove the remaining non-cropped movie
        os.remove(mc.fname_tot_rig[0])

    # If specified in the parameters, apply piecewise-rigid motion correction
    if parameters['pw_rigid']:
        logging.info(f'{index} Performing piecewise-rigid motion correction')
        t0 = datetime.datetime.today()
        # Perform non-rigid (piecewise rigid) motion correction. Use the rigid result as a template.
        mc.motion_correct_pwrigid(save_movie=True, template=total_template_rig)
        # Obtain template and filename
        total_template_els = mc.total_template_els
        fname_tot_els = mc.fname_tot_els[0]

        dt = int((datetime.datetime.today() - t0).seconds /
                 60)  # timedelta in minutes
        meta_pkl_dict['pw_rigid'] = {
            'template': total_template_els,
            'x_shifts': mc.x_shifts_els,
        'y_shifts': mc.y_shifts_els  # initially left out because they take up space
        }
        output['meta']['duration']['pw_rigid'] = dt
        logging.info(
            f'{index} Piecewise-rigid motion correction finished. dt = {dt} min'
        )

        # Load the movie saved by CaImAn, which is in the wrong
        # directory and is not yet cropped
        logging.info(f'{index} Loading pw-rigid movie for cropping')
        m_els = cm.load(fname_tot_els)
        logging.info(f'{index} Loaded pw-rigid movie for cropping')
        # Get the cropping points determined by the maximal rigid shifts
        x_, _x, y_, _y = get_crop_from_pw_rigid_shifts(
            np.array(mc.x_shifts_els), np.array(mc.y_shifts_els))
        # Crop the movie
        logging.info(
            f'{index} Cropping and saving pw-rigid movie with cropping points: [x_, _x, y_, _y] = {[x_, _x, y_, _y]}'
        )
        m_els = m_els.crop(x_, _x, y_, _y, 0, 0)
        meta_pkl_dict['pw_rigid']['cropping_points'] = [x_, _x, y_, _y]
        # Save the movie
        fname_tot_els = m_els.save(data_dir + 'main/' + file_name + '_els' +
                                   '.mmap',
                                   order='C')
        logging.info(
            f'{index} Cropped and saved pw-rigid movie as {fname_tot_els}')

        # Remove the remaining non-cropped movie
        os.remove(mc.fname_tot_els[0])

        # Store the total path in output
        output['main'] = fname_tot_els

    # Write meta results dictionary to the pkl file
    with open(output_meta_pkl_file_path, 'wb') as pkl_file:
        pickle.dump(meta_pkl_dict, pkl_file)

    # Write necessary variables to the trial index and row
    row.loc['motion_correction_output'] = str(output)
    row.loc['motion_correction_parameters'] = str(parameters)

    # Compute the basic metrics 'crispness'
    get_metrics(index, row, crispness=True)

    # Create source extraction images in advance:
    logging.info(f'{index} Creating corr and pnr images in advance')
    index, row = src.steps.source_extraction.get_corr_pnr(index, row)
    logging.info(f'{index} Created corr and pnr images')

    return index, row
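
# A hypothetical usage sketch for the pipeline step above (names and values are
# assumptions, not from the original code): the parameters dict is read both by
# this step ('pw_rigid', 'save_movie_rig') and, through params.CNMFParams, by
# CaImAn's 'motion' parameter group; index and row are assumed to come from the
# pipeline's analysis-state DataFrame.
# c, dview, n_processes = cm.cluster.setup_cluster(backend='local',
#                                                  n_processes=None,
#                                                  single_thread=False)
# parameters = {'pw_rigid': True, 'save_movie_rig': False,
#               'max_shifts': (6, 6), 'strides': (48, 48), 'overlaps': (24, 24),
#               'max_deviation_rigid': 3, 'border_nan': 'copy'}
# index, row = main(index, row, parameters, dview)
# cm.stop_server(dview=dview)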