Example #1
def VTDM_prepb(spikefile, Dsfilename, dirichfilename, Mx, My = None, 
               domain=None, Wx = None, Wy = None, dtype = np.complex128):
    """
    Prepare for decoding by generating two files, Dsfilename and dirichfilename.
    Must be called before VTDMb for each new spikefile.
    parameters:
    spikefile: the file generated by VTEM containing spike info
    Dsfilename: name of the resulting file, containing the spatial
                dirichlet coefficients for all RFs

    dirichfilename: the name of the resulting file, containing the
                    spatial reconstruction function for all RFs

    Mx: order of dirichlet space in x variable
    My: order of dirichlet space in y variable
        if not specified, My = Mx

    domain: 4-list or 4-tuple, [xstart, xend, ystart, yend]
            the spatial domain to recover
    Wx: bandwidth in x variable
        if not specified, will use the info in spikefile
    Wy: bandwidth in y variable
        if not specified, will use the info in spikefile

    dtype: np.complex128 or np.complex64, precision used to compute the
           dirichlet coefficients; the output files will be in the real
           format derived from dtype

    
    The coordinate system is given by the following:
        
        
        row (width / X) major
        -----------------------------------------> width
        |   Y
        |   ^-------------------->
        |   |-------------------->
        |   |-------------------->
        |   |--------------------> X
        | 
        |
        v
        height
        
    VTDM_prep computes the inner products between RFs and stores them in Dswfilename.
    VTDM_prepb computes only the dirichlet coefficients of each RF and stores them in
    Dsfilename; the inner products are then computed in VTDMb.

    """

    if dtype not in [np.complex128, np.complex64]:
        raise TypeError("dtype must be complex128 or complex64")

    spfile = ss.vtem_storage_read(spikefile)

    spfile.read_vrf_type()
    
    num_neurons = spfile.read_num_neurons()
    print "total neurons: %d" % (num_neurons)
    
    Wt,Wx1,Wy1,Px,Py,dx,dy = spfile.read_video_attributes()

    if Wx is None:
        Wx = Wx1
    else:
        Wx = float(Wx)
    if Wy is None:
        Wy = Wy1
    else:
        Wy = float(Wy)

    Mx = int(Mx)
    if My is None:
        My = Mx
    else:
        My = int(My)

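    # The receptive-field grid below is sized from the Dirichlet order and the
    # bandwidth: a Dirichlet space of order M and bandwidth W spans 2*pi*M/W in
    # space, which is then discretized at the video pixel pitch (dx, dy) and
    # rounded up to a multiple of 16 (presumably for GPU/FFT-friendly sizes).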
    rfSx = (2*np.pi) * (Mx/Wx)
    rfSy = (2*np.pi) * (My/Wy)
    rfPx = int(np.ceil(rfSx/dx/16) * 16)
    rfPy = int(np.ceil(rfSy/dy/16) * 16)
    
    oriSx = Px*dx
    oriSy = Py*dy
    
    if domain is None:
        domain = [-oriSx/2,oriSx/2,-oriSy/2,oriSy/2]
    else:
        domain = [max(float(domain[0]),-oriSx/2), 
                  min(float(domain[1]), oriSx/2), 
                  max(float(domain[2]),-oriSy/2),
                  min(float(domain[3]),oriSy/2)]

    spfile.select_neurons(domain)
    print "select neurons: %d" % (spfile.decode_neurons)

    decSx = domain[1]-domain[0]
    decSy = domain[3]-domain[2]

    decPx = int(np.round(decSx / dx))
    decPy = int(np.round(decSy / dy))


    if spfile.filter_type == "gabor":
        gb = vrf.vrf_gabor((rfPy, rfPx), dx=rfSx/rfPx, dy=rfSy/rfPy, 
                           scale=4, dtype=dtype)
        gb.load_parameters(num_neurons=spfile.decode_neurons, 
                           h_alpha=spfile.alpha, h_l=spfile.l, h_x0=spfile.x0, 
                           h_y0=spfile.y0, h_ab=spfile.ab, KAPPA=spfile.KAPPA)
    else:
        gb = vrf.vrf_cs((rfPy, rfPx), dx=rfSx/rfPx, dy=rfSy/rfPy,
                        scale=4, dtype=dtype)
        gb.load_parameters(num_neurons=spfile.decode_neurons,
                           h_alpha=spfile.alpha, h_x0=spfile.x0, h_y0=spfile.y0,
                           sigma_center=spfile.sigma_center,
                           sigma_surround=spfile.sigma_surround)
    
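    # The two steps below follow the docstring: compute_Ds produces the spatial
    # dirichlet coefficients of every RF (written to Dsfilename), and
    # compute_dirich_space_fft evaluates the corresponding spatial reconstruction
    # functions on the decoding grid (written to dirichfilename) for use by VTDMb.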
    d_Ds = gb.compute_Ds(Mx, My)
    write_memory_to_file(d_Ds, Dsfilename)

    d_dirich = gb.compute_dirich_space_fft(d_Ds, Mx, My, decPx, 
                                           decPy, decSx, decSy, Wx, Wy)
    
    write_memory_to_file(d_dirich, dirichfilename)
    del d_dirich

    spfile.close()
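
# Usage sketch (not part of the original example; all filenames and numeric values
# are hypothetical placeholders): restrict decoding to a sub-domain and override
# the bandwidths read from the spikefile. The same Mx, My, domain, Wx and Wy must
# then be passed to VTDMb (Example #2).
if __name__ == "__main__":
    import numpy as np
    VTDM_prepb('spikes.h5', 'ds.h5', 'dirich.h5', Mx=30, My=30,
               domain=[-1.0, 1.0, -1.0, 1.0], Wx=2*np.pi*10, Wy=2*np.pi*10)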
Example #2
def VTDMb(spikefile, Dsfilename, dirichfilename, start_time, end_time,
         dt, Mx, My=None, Mt=None, domain=None, Wx=None, Wy=None, Wt=None, 
         lamb=0.0, dtype=np.float64, rnn=False, alpha=50000,steps=4000, 
         stitching=False, stitch_interval=None, output_format=0,
         output="rec"):
    """
    Reconstruct video using VTDM with dirichlet kernel, assuming IAF neurons.
    VTDM_prepb must be called before this function, once for each spikefile and
    for each new choice of Mx, My.

    -------------------------------------------------------------------
    
    parameters:

    Required:
    spikefile: the file generated by VTEM containing spike info
    Dsfilename: file generated by VTDM_prepb
    dirichfilename: file generated by VTDM_prepb
    start_time: the starting time of the segment to reconstruct
    end_time: the ending time of the segment to reconstruct
    dt: the interval between two consecutive frames in the output
    Mx: order of dirichlet space in x variable
        must be the same as used in VTDM_prepb
    
    Optional:
    My: order of dirichlet space in y variable
        must be the same as used in VTDM_prepb
        if not specified, My = Mx
    Mt: order of dirichlet space in t variable
        if not specified, will infer from spikefile
    domain: 4-list or 4-tuple, [xstart, xend, ystart, yend]
            the spatial domain to recover
            must be the same as in VTDM_prepb
    Wx: bandwidth in x variable
        if not specified, will use the info in spikefile
    Wy: bandwidth in y variable
        if not specified, will use the info in spikefile
    Wt: bandwidth in t variable
        if not specified, will use the info in spikefile
    lamb: smoothing coefficient \lambda
    dtype: np.float128 or np.float64, data type of the output
           If not specified, will be set to np.float64
    stitching: True if the recovery algorithm should use stitching.
               False otherwise.
               If not specified, will be set to False.
               Stitching should ideally be used only for long videos.
    stitch_interval: If stitching is set to True, stitch_interval
                    will set the individual segment lengths.
                    If not specified, will default to an interval
                    corresponding to 20 frames. If a larger video is
                    being decoded, this should be explicitly set to a
                    lower value if a single GPU is being used.
    output_format: 0 to write recovered video to avi file
                   1 to write recovered video to h5 file
                   Anything else to not write recovered video to disc
                   If not specified, will be set to 0
    output: If output_format is 0 or 1, specifies the output filename
            If not specified, will be set to "rec"

    -------------------------------------------------------------------
    
    returns: The recovered video as a numpy array

    -------------------------------------------------------------------

    The coordinate system is given by the following:
        
        
        row (width / X) major
        -----------------------------------------> width
        |   Y
        |   ^-------------------->
        |   |-------------------->
        |   |-------------------->
        |   |--------------------> X
        | 
        |
        v
        height


    """

    spfile = ss.vtem_storage_read(spikefile)
    spfile.read_vrf_type()

    num_neurons = spfile.read_num_neurons()
    print "total neurons: %d" % (num_neurons)
    
    Wt1,Wx1,Wy1,Px,Py,dx,dy = spfile.read_video_attributes()

    if Wx is None:
        Wx = Wx1
    else:
        Wx = float(Wx)
    if Wy is None:
        Wy = Wy1
    else:
        Wy = float(Wy)
    if Wt is None:
        Wt = Wt1
    else:
        Wt = float(Wt)

    Mx = int(Mx)
    if My is None:
        My = Mx
    else:
        My = int(My)

    rfSx = (2*np.pi) * (Mx/Wx)
    rfSy = (2*np.pi) * (My/Wy)
    rfPx = int(np.ceil(rfSx/dx/16) * 16)
    rfPy = int(np.ceil(rfSy/dy/16) * 16)
    
    oriSx = Px*dx
    oriSy = Py*dy

    if domain is None:
        domain = [-oriSx/2,oriSx/2,-oriSy/2,oriSy/2]
    else:
        domain = [max(float(domain[0]),-oriSx/2), 
                  min(float(domain[1]), oriSx/2), 
                  max(float(domain[2]),-oriSy/2),
                  min(float(domain[3]),oriSy/2)]
        
    spfile.select_neurons(domain)
    print "select neurons: %d" % (spfile.decode_neurons)
    
    decSx = domain[1]-domain[0]
    decSy = domain[3]-domain[2]

    decPx = int(np.round(decSx/dx))
    decPy = int(np.round(decSy/dy))
    
    if Mt is None:
        Mt = int(round(Wt/np.pi))
    else:
        Mt = int(Mt)

    end = float(end_time)
    if stitching:
        if stitch_interval is None:
            stitch_interval = 20*dt
        stitch_interval = min(stitch_interval,end_time)
        end = round(float(min(start_time + stitch_interval, end_time))/dt)*dt
        overlap = 0.2*stitch_interval
    else:
        stitch_interval = float(end_time)
        overlap = 0

    start = float(start_time)
    
    s = read_file(dirichfilename).shape    
    total_l = int(round((end_time - start_time)/dt))
    overlap_l = int(round(overlap/dt))
    
    if overlap_l<3:
        if round(stitch_interval/dt)>=8:
            overlap_l = 4
            overlap = 4*dt
        else:
            overlap = 0
            overlap_l = 0

    interval_l = int(round(stitch_interval/dt))
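    # Cross-fade windows for temporal stitching: within each overlap of
    # overlap_l frames, the new segment is weighted by the rising half of a
    # Hanning window (left_window) and the tail of the previous segment by its
    # complement (right_window), so the two weights sum to one at every frame.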
    left_window = np.hanning(2*overlap_l)[0:overlap_l]
    left_window.shape = (left_window.shape[0], 1, 1)
    left_window = np.tile(left_window, (1, s[1], s[2]))
    right_window = np.ones(left_window.shape) - left_window

    u = np.empty([total_l, s[1], s[2]], dtype=dtype)
    
    i = 0
                          
    while True:
        tk1,tk2,neuron_ind,h_kappa,h_delta,h_bias,h_sigma = \
            spfile.read_select_spikes([start, end])

        h_norm = np.ones(spfile.num_neurons,np.float64)
        nonzerosigma = np.nonzero(h_sigma!=0)
        h_norm[nonzerosigma] = 1.0/(h_kappa[nonzerosigma]*h_sigma[nonzerosigma])
        h_norm = h_norm/h_norm.max()
        h_norm = h_norm[spfile.choose]
        
        cp = dr.dirichlet(spfile.decode_neurons, tk1,tk2,neuron_ind,h_kappa, 
                          h_delta, h_bias, h_norm, Wt, Mt, dtype)
        cp.compute_q()
        cp.compute_Gb(Dsfilename, float(lamb))

        if rnn:
            cp.d_q = nn.rnn3(cp.d_G, cp.d_q, alpha=alpha, steps=steps)
        else:
            cp.d_q = ls.solve_eq_sym(cp.d_G, cp.d_q)
        cp.freeG()
                        
        u_rec = cp.reconstruct(dirichfilename, [0, round((end-start)/dt)*dt],dt)
        u_rec = u_rec.get()
        
        
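        # Assemble the full recording: the first segment is copied directly (minus
        # its trailing overlap), and each later segment is cross-faded into the
        # previous one over overlap_l frames before its interior is copied; the
        # loop ends once the current segment reaches end_time.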
        if overlap_l:
            if i==0:
                if end_time-end<dt:
                    u = u_rec
                    break
                else:
                    u[0:(interval_l - overlap_l),:,:] = u_rec[0:-overlap_l,:,:]
                    prev = u_rec[-overlap_l:,:,:] 
            else:
                l_ind = i*(interval_l - overlap_l)
                r_ind = (i+1)*(interval_l - overlap_l)
                u[l_ind:(l_ind + overlap_l),:,:] = \
                    u_rec[0:overlap_l,:,:]*left_window + prev*right_window

                if end_time-end<dt:
                    u[(l_ind + overlap_l):,:,:] = \
                        u_rec[(-u.shape[0]+l_ind+overlap_l):,:,:]
                    break

                u[(l_ind + overlap_l):r_ind,:,:] = \
                    u_rec[overlap_l:interval_l-overlap_l,:,:]

                prev = u_rec[-overlap_l:,:,:]

        else:
            u[(i*interval_l):((i*interval_l) + u_rec.shape[0]),:,:] = u_rec
            if end_time-end<dt:
                break
        
        start = start + round(stitch_interval/dt)*dt - round(overlap/dt)*dt
        end = min(start + round(stitch_interval/dt)*dt, float(end_time))
        
        if (float(end_time) - end < overlap):
            end = end_time
    

        i = i + 1

    if output_format==0:
        vio.write_video(u, output)
    elif output_format==1:
        write_memory_to_file(u, output)
    return u
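
# Usage sketch (not part of the original example; the filenames, Mx=30, and the
# 0-1 s window with dt=0.01 are hypothetical placeholders). VTDM_prepb must run
# first so the coefficient and reconstruction-function files exist for the chosen
# Mx, My and domain; VTDMb then solves for the stimulus and returns it.
if __name__ == "__main__":
    VTDM_prepb('spikes.h5', 'ds.h5', 'dirich.h5', Mx=30)
    rec = VTDMb('spikes.h5', 'ds.h5', 'dirich.h5', 0.0, 1.0, 0.01, Mx=30,
                stitching=True, output_format=1, output='rec.h5')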
Example #3
def decode_video(spikefile, Dswfilename, dirichfilename, start_time,
                 end_time, dt, Mx, My=None, Mt=None, domain=None, Wx=None, 
                 Wy=None, Wt=None, lamb=0.0, dtype=np.float64, rnn=False, 
                 alpha=50000, steps=4000, stitching=False, 
                 stitch_interval=None, spatial_stitching=False, 
                 spatial_interval=[96,96], precompute=True, write_blocks=False,
                 output_format=0, output="rec"):
    """
    Reconstruct video using VTDM with dirichlet kernel, assuming IAF neurons
    
    Parameters
    ----------
    
    spikefile : string
        the file generated by VTEM containing spike info
    Dswfilename : string
        file generated by VTDM_prep
    dirichfilename : string 
        file generated by VTDM_prep
    start_time : float
        the starting time of the segment to reconstruct
    end_time : float 
        the ending time of the segment to reconstruct
    dt : float 
        the interval between two consecutive frames in the output
    Mx : integer
        order of dirichlet space in x variable
        must be the same as used in VTDM_prep
    My : integer, optional
        order of dirichlet space in y variable
        must be the same as used in VTDM_prep
        if not specified, My = Mx
    Mt : integer, optional
        order of dirichlet space in t variable
        if not specified, will infer from spikefile
    domain: list, optional 
        4-list or 4-tuple, [xstart, xend, ystart, yend]
        the spatial domain to recover
        must be the same as in VTDM_prep
    Wx : float, optional
        bandwidth in x variable
        if not specified, will use the info in spikefile
    Wy : float, optional 
        bandwidth in y variable
        if not specified, will use the info in spikefile
    Wt : float, optional 
        bandwidth in t variable
        if not specified, will use the info in spikefile
    lamb : float, optional 
        smoothing coefficient \lambda
    dtype : optional  
        np.float128 or np.float64, data type of the output
        If not specified, will be set to np.float64
    stitching: bool, optional
        True if the recovery algorithm should use stiching.
        False otherwise.
        If not specified, will be set to False.
        Stiching should ideally be used only for long videos.
    stitch_interval : integer, optional 
        If stitching is set to True, stitch_interval
        will set the individual segment lengths.
        If not specified, will default to an interval
        corresponding to 20 frames. If a larger video is
        being decoded, this should be explicitly set to a
        lower value if a single GPU is being used.
    spatial_stitching : bool, optional
        True if the recovery algorithm should use
        spatial stitching. If the video is of high
        resolution, should be set to True.
        Spatial stitching will be disabled if the domain
        of recovery is less than 96x96 pixels.
        If not specified, will be set to False.
    spatial_interval : list, optional
        If spatial_stitching is set to True, will determine
        the domain for spatial stitching.
        If not specified, will default to [96,96] pixels.
    precompute : bool, optional
        If set to True, inner products will be calculated and
        stored in an h5 file before the VTDM call.
        If not specified, will default to True
    write_blocks: bool, optional
        Only used if spatial_stitching is True.
        If set to True, will write the individual blocks to disc 
        as well.
        If not specified, will be set to False
    output_format : integer, optional 
        0 to write recovered video to avi file
        1 to write recovered video to h5 file
        Anything else to not write recovered video to disc
        If not specified, will be set to 0
    output : string, optional 
        output filename in which the reconstructed video will be stored
        If not specified, will be set to "rec" 
    
    Returns
    -------
    The recovered video as a numpy array

    Notes
    -----

    On a single device with 3 GB memory, around 27000 spikes can be decoded.
    Spatial stitching and temporal stitching can be used to recover a video
    which produces more spikes than what can be decoded on single GPU. 
    See Example below.
    
    The coordinate system is given by the following:
        
    ::
    
        row (width / X) major
    
        -----------------------------------------> width
        |   Y
        |   ^-------------------->
        |   |-------------------->
        |   |-------------------->
        |   |--------------------> X
        | 
        |
        v
        height

    Examples
    --------
    
    >>> import atexit
    >>> import pycuda.driver as cuda
    >>> import numpy as np
    >>> from vtem import vtem,vtdm
    >>> cuda.init()
    >>> context1 = cuda.Device(0).make_context()
    >>> atexit.register(cuda.Context.pop)
    >>> vtem.VTEM_Gabor_IAF('fly_large.avi', 'spikes.h5',
    ...                     2*np.pi*10, h5input=False)
    >>> vtdm.decode_video('spikes.h5', 'dsw.h5', 'dirich.h5',
    ...                   0, 1, 0.01, Mx, rnn=True, alpha=5000, steps=4000,
    ...                   dtype=np.float32, stitching=True, stitch_interval=0.2,
    ...                   spatial_stitching=True, spatial_interval=[80, 80],
    ...                   output_format=0, write_blocks=True)
    
    """
    
    if not spatial_stitching:
        if precompute:
            VTDM_prep(spikefile, Dswfilename, dirichfilename, Mx, My, 
                      domain, Wx, Wy, dtype=np.complex128)
            VTDM(spikefile, Dswfilename, dirichfilename, start_time,
                 end_time, dt, Mx, My, Mt, domain, Wx, 
                 Wy, Wt, lamb, dtype, rnn, 
                 alpha, steps, stitching, 
                 stitch_interval,False,output_format, output)
        else:
            VTDM_prepb(spikefile, Dswfilename, dirichfilename, Mx, My, 
                      domain, Wx, Wy, dtype=np.complex128)
            VTDMb(spikefile, Dswfilename, dirichfilename, start_time,
                 end_time, dt, Mx, My, Mt, domain, Wx, 
                 Wy, Wt, lamb, dtype, rnn, 
                 alpha, steps, stitching, 
                 stitch_interval,False, output_format, output)

    else:
        spfile = ss.vtem_storage_read(spikefile)
        Wt1,Wx1,Wy1,Px,Py,dx,dy = spfile.read_video_attributes()

        if Wt is None:
            Wt = Wt1
        else:
            Wt = float(Wt)
        if Wx is None:
            Wx = Wx1
        else:
            Wx = float(Wx)
        if Wy is None:
            Wy = Wy1
        else:
            Wy = float(Wy)
               
        Mx = int(Mx)
        if My is None:
            My = Mx
        else:
            My = int(My)
       
        oriSx = Px*dx
        oriSy = Py*dy
    
        if domain is None:
            domain = [-oriSx/2,oriSx/2,-oriSy/2,oriSy/2]
        else:
            domain = [max(float(domain[0]),-oriSx/2), 
                      min(float(domain[1]), oriSx/2), 
                      max(float(domain[2]),-oriSy/2),
                      min(float(domain[3]), oriSy/2)]
            
        decSx = domain[1]-domain[0]
        decSy = domain[3]-domain[2]
       
        decPx = int(np.round(decSx / dx))
        decPy = int(np.round(decSy / dy))

        spatial_interval = [min(spatial_interval[0],decPx),
                            min(spatial_interval[1],decPy)]
        overlap = [int(round(.2*spatial_interval[0])),
                   int(round(.2*spatial_interval[1]))]
        start_y = max(domain[3] - spatial_interval[1]*dy, domain[2])
        end_y = domain[3] 
        
        total_l = int(round((end_time - start_time)/dt))
    
        u = np.zeros([total_l, decPy, decPx], dtype=dtype)
    
        start_y_i = 0
        end_y_i = spatial_interval[1]

        i=0
        
        if write_blocks:
            output_format_int = 1 if output_format==1 else 0
        else:
            output_format_int = 2


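        # Spatial stitching: blocks are traversed from the highest x/y coordinates
        # toward domain[0]/domain[2]; each block is decoded independently (with the
        # same temporal settings) and accumulated into u through a 2D blending window.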
        while True:
            start_x = max(domain[1] - spatial_interval[0]*dx,domain[0])
            end_x = domain[1]
            start_x_i = decPx - spatial_interval[0]
            end_x_i = decPx
            while True:
                curr_domain = [start_x,end_x,start_y,end_y]
                if precompute:
                    VTDM_prep(spikefile, Dswfilename, dirichfilename, Mx, My, 
                              curr_domain, Wx, Wy, dtype=np.complex128)
                else:
                    VTDM_prepb(spikefile, Dswfilename, dirichfilename, Mx, My, 
                              curr_domain, Wx, Wy, dtype=np.complex128)
                    
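                # Blending window for this block: a separable product of 1D Hanning
                # tapers. The taper width is zero on sides that touch the domain
                # boundary, so adjacent blocks cross-fade smoothly across their overlap.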
                window_matrix = np.ones([end_y_i-start_y_i, end_x_i-start_x_i],
                                        dtype=dtype)
                left_size = 0 if start_x_i==0 else overlap[0]
                right_size = 0 if end_x_i==decPx else overlap[0]
                left_w = np.hanning(2*left_size)[:left_size]
                right_w = 1 - np.hanning(2*right_size)[:right_size]
                x_window  = np.concatenate((left_w, \
                                np.ones(end_x_i-start_x_i-left_size-right_size),
                                right_w))
                top_size = 0 if start_y_i==0 else overlap[1]
                bottom_size = 0 if end_y_i==decPy else overlap[1]
                top_w = np.hanning(2*top_size)[:top_size]
                bottom_w = 1 - np.hanning(2*bottom_size)[:bottom_size]
                y_window  = np.concatenate((top_w, \
                                np.ones(end_y_i-start_y_i-bottom_size-top_size),
                                bottom_w))
                y_window.shape = (y_window.shape[0], 1)
                x_window.shape = (1, x_window.shape[0])
                window_matrix = y_window.dot(x_window) 
                

                if precompute:
                    u[:,start_y_i:end_y_i,start_x_i:end_x_i] += np.multiply( \
                        VTDM(spikefile, Dswfilename, dirichfilename, 
                             start_time,end_time, dt, Mx, My, 
                             Mt, curr_domain, Wx, Wy, Wt, lamb, dtype, rnn, 
                             alpha, steps, stitching, 
                             stitch_interval, output_format_int,
                             output+str(i)),window_matrix)
                else:
                    u[:,start_y_i:end_y_i,start_x_i:end_x_i] += np.multiply( \
                        VTDMb(spikefile, Dswfilename, dirichfilename, 
                             start_time,end_time, dt, Mx, My, 
                             Mt, curr_domain, Wx, Wy, Wt, lamb, dtype, rnn, 
                             alpha, steps, stitching, 
                             stitch_interval, output_format_int,
                             output+str(i)),window_matrix)
                i+=1
                if start_x==domain[0]:
                    break
                start_x = max(start_x - spatial_interval[0]*dx + overlap[0]*dx,
                              domain[0])
                end_x = end_x - spatial_interval[0]*dx + overlap[0]*dx
                start_x_i = max(start_x_i - spatial_interval[0] + overlap[0],0)
                end_x_i = end_x_i - spatial_interval[0] + overlap[0]
                
            if start_y==domain[2]:
                break
            start_y = max(start_y - spatial_interval[1]*dy + overlap[1]*dy,
                          domain[2])
            end_y = end_y - spatial_interval[1]*dy + overlap[1]*dy
            start_y_i = end_y_i - overlap[1]
            end_y_i = min(start_y_i + spatial_interval[1],decPy)
            
        if output_format==0:
            vio.write_video(u, output+".avi")
        elif output_format==1:
            write_memory_to_file(u, output+".h5")
        
        return u
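
# Usage sketch (not part of the original example; the filenames, Mx=30, the 0-1 s
# window and the output name are hypothetical placeholders): the simplest call
# path, with no temporal or spatial stitching, which internally runs VTDM_prep
# followed by VTDM and returns the recovered video as a numpy array.
if __name__ == "__main__":
    rec = decode_video('spikes.h5', 'dsw.h5', 'dirich.h5', 0.0, 1.0, 0.01, Mx=30,
                       output_format=1, output='rec.h5')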