예제 #1
0
파일: nonrigid.py 프로젝트: ogeesan/suite2p
def prepare_masks(refImg1, ops):
    """Split the reference image into blocks and precompute per-block masks and FFTs.

    For every block listed in ``ops['yblock']`` / ``ops['xblock']`` this builds
    a multiplicative taper mask, an additive offset mask, and the conjugate,
    magnitude-normalised, Gaussian-filtered 2-D FFT of that block of the
    reference image.

    Parameters
    ----------
    refImg1 : 2-D array
        Reference image. It is copied, not modified.
    ops : dict
        Reads '1Preg', 'spatial_taper', 'smooth_sigma', 'yblock', 'xblock'
        and 'pad_fft' (plus whatever register.one_photon_preprocess reads).

    Returns
    -------
    maskMul1 : float32 array, shape (nb, 1, ly, lx)
        Whole-image taper windowed to each block, times a per-block taper.
    maskOffset1 : float32 array, shape (nb, 1, ly, lx)
        Block mean of the reference times (1 - maskMul1).
    cfRefImg1 : complex64 array, shape (nb, 1, ly, lx)
        Per-block reference FFTs. When ``ops['pad_fft']`` is set the shape is
        (nb, 1, next_fast_len(ly), next_fast_len(lx)) and the FFT is zero-padded.
    """
    refImg0 = refImg1.copy()
    if ops['1Preg']:
        maskSlope = ops['spatial_taper']
    else:
        maskSlope = 3 * ops['smooth_sigma']  # slope of taper mask at the edges
    Ly, Lx = refImg0.shape
    # whole-image taper; each block later takes its own window of it
    maskMul = register.spatial_taper(maskSlope, Ly, Lx)

    if ops['1Preg']:
        refImg0 = register.one_photon_preprocess(refImg0[np.newaxis, :, :],
                                                 ops).squeeze()

    nb = len(ops['yblock'])
    # block size is taken from the first block; all blocks are assumed equal
    ly = ops['yblock'][0][1] - ops['yblock'][0][0]
    lx = ops['xblock'][0][1] - ops['xblock'][0][0]
    if ops['pad_fft']:
        fftShape = (next_fast_len(ly), next_fast_len(lx))
    else:
        fftShape = (ly, lx)

    maskMul1 = np.zeros((nb, 1, ly, lx), 'float32')
    maskOffset1 = np.zeros((nb, 1, ly, lx), 'float32')
    cfRefImg1 = np.zeros((nb, 1) + fftShape, 'complex64')

    # The per-block edge taper and the Gaussian filter depend only on the
    # common block / FFT size, so compute them once rather than per block.
    blockTaper = register.spatial_taper(2 * ops['smooth_sigma'], ly, lx).astype('float32')
    fhg = register.gaussian_fft(ops['smooth_sigma'], fftShape[0], fftShape[1])

    for n in range(nb):
        yb, xb = ops['yblock'][n], ops['xblock'][n]
        yind = np.arange(yb[0], yb[-1]).astype('int')
        xind = np.arange(xb[0], xb[-1]).astype('int')
        win = np.ix_(yind, xind)

        refImg = refImg0[win]
        maskMul1[n, 0] = maskMul[win].astype('float32') * blockTaper
        # offset = block mean weighted by the part removed by the taper
        maskOffset1[n, 0] = (refImg.mean() * (1. - maskMul1[n, 0])).astype(np.float32)

        # Zero-pad to fftShape when pad_fft is set, so the result matches the
        # preallocated cfRefImg1 (an unpadded FFT would not fit it).
        if ops['pad_fft']:
            cfRefImg = np.conj(fft.fft2(refImg, fftShape))
        else:
            cfRefImg = np.conj(fft.fft2(refImg))
        # keep phase only (eps0 presumably a small positive constant that
        # guards against division by zero), then apply the gaussian filter
        cfRefImg = cfRefImg / (eps0 + np.absolute(cfRefImg))
        cfRefImg *= fhg

        cfRefImg1[n, 0] = cfRefImg.astype('complex64')
    return maskMul1, maskOffset1, cfRefImg1
예제 #2
0
def phasecorr(data, refAndMasks, ops):
    """Loop through blocks and compute phase correlations.

    Each frame is processed in four steps:

    1. Cut the frame into the blocks in ``ops['yblock']`` / ``ops['xblock']``.
    2. Apply ``addmultiply`` with the per-block masks.
    3. Phase-correlate each block with the per-block reference FFT.
    4. Clip the correlation to a window of half-width ``lcorr + lpad``.

    Maps whose SNR is below ``ops['snr_thresh']`` are replaced by versions
    smoothed across blocks with ``ops['NRsm']``, for up to two rounds. The
    peak of each map is then refined by upsampling with ``Kmat``.

    Parameters
    ----------
    data : 3-D array, (nimg, Ly, Lx)
        Frames to register.
    refAndMasks : sequence
        ``(maskMul, maskOffset, cfRefImg)``, presumably as returned by
        ``prepare_masks``. Each element is squeezed before use.
    ops : dict
        Reads 'yblock', 'xblock', 'maxregshiftNR', '1Preg', 'NRsm',
        'snr_thresh' (plus whatever the helper functions read).

    Returns
    -------
    ymax1, xmax1 : float32 arrays, shape (nimg, nb)
        Sub-pixel y / x shift of every block in every frame.
    cmax1 : float32 array, shape (nimg, nb)
        Peak of the upsampled correlation for every block in every frame.
    """
    nimg, _, _ = data.shape  # data must be 3-D
    maskMul = refAndMasks[0].squeeze()
    maskOffset = refAndMasks[1].squeeze()
    cfRefImg = refAndMasks[2].squeeze()

    # NOTE(review): ly/lx come from the reference FFT, but Y below is filled
    # with unpadded data blocks, so a zero-padded cfRefImg (ops['pad_fft'])
    # will not fit here -- confirm pad_fft is unsupported for nonrigid.
    ly, lx = cfRefImg.shape[-2:]

    # maximum registration shift allowed
    maxregshift = np.round(ops['maxregshiftNR'])
    lcorr = int(np.minimum(maxregshift, np.floor(np.minimum(ly, lx) / 2.) - lpad))
    nb = len(ops['yblock'])
    ncc = 2 * lcorr + 2 * lpad + 1  # side length of each clipped correlation map

    # Preprocessing for 1P recordings. The filtered frames (not the raw ones)
    # must be the ones correlated; prepare_masks filters the reference with
    # the same one_photon_preprocess call.
    if ops['1Preg']:
        src = register.one_photon_preprocess(data.astype(np.float32), ops)
    else:
        src = data

    # shifts and corrmax
    ymax1 = np.zeros((nimg, nb), np.float32)
    xmax1 = np.zeros((nimg, nb), np.float32)
    cmax1 = np.zeros((nimg, nb), np.float32)
    ymax = np.zeros((nb,), np.int32)
    xmax = np.zeros((nb,), np.int32)

    # NOTE(review): in the 1P case the float32 filtered frames are truncated
    # to int16 here; kept to preserve addmultiply's input dtype.
    Y = np.zeros((nimg, nb, ly, lx), 'int16')
    for n in range(nb):
        yind, xind = ops['yblock'][n], ops['xblock'][n]
        Y[:, n] = src[:, yind[0]:yind[-1], xind[0]:xind[-1]]
    Y = addmultiply(Y, maskMul, maskOffset)

    # Assign the results back: overwrite_x only allows the input to be
    # destroyed, it does not guarantee the transform is written in place.
    for n in range(nb):
        for t in range(nimg):
            Y[t, n] = fft2(Y[t, n], overwrite_x=True)
    Y = register.apply_dotnorm(Y, cfRefImg)
    for n in range(nb):
        for t in range(nimg):
            Y[t, n] = ifft2(Y[t, n], overwrite_x=True)

    x00, x01, x10, x11 = my_clip(Y, lcorr + lpad)
    cc0 = np.real(np.block([[x11, x10], [x01, x00]]))
    cc0 = np.transpose(cc0, (1, 0, 2, 3))
    cc0 = cc0.reshape((cc0.shape[0], -1))

    # cc2[j] holds the correlation maps smoothed j times across blocks
    R = ops['NRsm']
    cc2 = [cc0]
    for j in range(2):
        cc2.append(R @ cc2[j])
    cc2 = [c.reshape((nb, nimg, ncc, ncc)) for c in cc2]

    # Replace low-SNR maps by progressively smoother ones. ccsm is cc2[0],
    # so it is updated in place.
    ccsm = cc2[0]
    for n in range(nb):
        snr = np.ones((nimg,), 'float32')
        for j, ccj in enumerate(cc2):
            ism = snr < ops['snr_thresh']
            if not ism.any():
                break
            cc = ccj[n, ism, :, :]
            if j > 0:
                ccsm[n, ism, :, :] = cc
            snr[ism] = getSNR(cc, (lcorr, lpad), ops)

    mdpt = np.floor(nup / 2)
    for t in range(nimg):
        ccmat = np.zeros((nb, 2 * lpad + 1, 2 * lpad + 1), np.float32)
        for n in range(nb):
            # Search the peak only in the central (2*lcorr+1)^2 window, so the
            # (2*lpad+1)^2 patch around it stays inside the map.
            ix = np.argmax(ccsm[n, t][lpad:-lpad, lpad:-lpad], axis=None)
            ym, xm = np.unravel_index(ix, (2 * lcorr + 1, 2 * lcorr + 1))
            ccmat[n] = ccsm[n, t][ym:ym + 2 * lpad + 1, xm:xm + 2 * lpad + 1]
            ymax[n], xmax[n] = ym - lcorr, xm - lcorr
        ccb = np.dot(ccmat.reshape((nb, -1)), Kmat)
        imax = np.argmax(ccb, axis=1)
        cmax1[t] = np.amax(ccb, axis=1)
        ymax1[t], xmax1[t] = np.unravel_index(imax, (nup, nup))
        ymax1[t], xmax1[t] = (ymax1[t] - mdpt) / subpixel, (xmax1[t] - mdpt) / subpixel
        ymax1[t], xmax1[t] = ymax1[t] + ymax, xmax1[t] + xmax
    return ymax1, xmax1, cmax1