def VTDM_prepb(spikefile, Dsfilename, dirichfilename, Mx, My=None,
               domain=None, Wx=None, Wy=None, dtype=np.complex128):
    """Prepare decoding: compute the Dirichlet coefficients of each RF.

    Must be called before VTDMb for a new spikefile.

    Parameters
    ----------
    spikefile : string
        the file generated by VTEM containing spike info
    Dsfilename : string
        name of the resulting file, containing the spatial dirichlet
        coefficients for all RFs
    dirichfilename : string
        name of the resulting file, containing the spatial reconstruction
        function for all RFs
    Mx : integer
        order of dirichlet space in x variable
    My : integer, optional
        order of dirichlet space in y variable
        if not specified, My = Mx
    domain : list, optional
        4-list or 4-tuple, [xstart, xend, ystart, yend]
        the spatial domain to recover
    Wx : float, optional
        bandwidth in x variable
        if not specified, will use the info in spikefile
    Wy : float, optional
        bandwidth in y variable
        if not specified, will use the info in spikefile
    dtype : optional
        np.complex128 or np.complex64, accuracy of computing dirichlet
        coefficients; the output files will be in the real format derived
        from dtype

    Notes
    -----
    The coordinate system is row (width / X) major: X increases along the
    width of a frame, Y down the height.

    VTDM_prep computes the inner product between every pair of RFs and
    stores them in its Dswfilename.  VTDM_prepb, by contrast, stores only
    the dirichlet coefficient of each RF in Dsfilename; the inner products
    are computed later, inside VTDMb.
    """
    if dtype not in [np.complex128, np.complex64]:
        raise TypeError("dtype must be complex128 or complex64")

    spfile = ss.vtem_storage_read(spikefile)
    spfile.read_vrf_type()
    num_neurons = spfile.read_num_neurons()
    # single-argument print(): identical output under Python 2 and 3
    print("total neurons: %d" % (num_neurons))

    Wt, Wx1, Wy1, Px, Py, dx, dy = spfile.read_video_attributes()

    # Fall back to the bandwidths recorded in the spike file.
    if Wx is None:
        Wx = Wx1
    else:
        Wx = float(Wx)
    if Wy is None:
        Wy = Wy1
    else:
        Wy = float(Wy)

    Mx = int(Mx)
    if My is None:
        My = Mx
    else:
        My = int(My)

    # Spatial support of the RF grid implied by the dirichlet orders,
    # padded up to a multiple of 16 pixels (GPU-kernel friendly size).
    rfSx = (2*np.pi) * (Mx/Wx)
    rfSy = (2*np.pi) * (My/Wy)
    rfPx = int(np.ceil(rfSx/dx/16) * 16)
    rfPy = int(np.ceil(rfSy/dy/16) * 16)

    # Original stimulus extent; the recovery domain is clamped inside it
    # (origin at the frame center).
    oriSx = Px*dx
    oriSy = Py*dy
    if domain is None:
        domain = [-oriSx/2, oriSx/2, -oriSy/2, oriSy/2]
    else:
        domain = [max(float(domain[0]), -oriSx/2),
                  min(float(domain[1]), oriSx/2),
                  max(float(domain[2]), -oriSy/2),
                  min(float(domain[3]), oriSy/2)]

    # Keep only the neurons whose RFs fall in the requested domain.
    spfile.select_neurons(domain)
    print("select neurons: %d" % (spfile.decode_neurons))

    decSx = domain[1] - domain[0]
    decSy = domain[3] - domain[2]
    decPx = int(np.round(decSx / dx))
    decPy = int(np.round(decSy / dy))

    # Build the RF model matching the type recorded in the spike file.
    if spfile.filter_type == "gabor":
        gb = vrf.vrf_gabor((rfPy, rfPx), dx=rfSx/rfPx, dy=rfSy/rfPy,
                           scale=4, dtype=dtype)
        gb.load_parameters(num_neurons=spfile.decode_neurons,
                           h_alpha=spfile.alpha, h_l=spfile.l,
                           h_x0=spfile.x0, h_y0=spfile.y0,
                           h_ab=spfile.ab, KAPPA=spfile.KAPPA)
    else:
        gb = vrf.vrf_cs((rfPy, rfPx), dx=rfSx/rfPx, dy=rfSy/rfPy,
                        scale=4, dtype=dtype)
        gb.load_parameters(num_neurons=spfile.decode_neurons,
                           h_alpha=spfile.alpha,
                           h_x0=spfile.x0, h_y0=spfile.y0,
                           sigma_center=spfile.sigma_center,
                           sigma_surround=spfile.sigma_surround)

    # Dirichlet coefficients of each RF -> Dsfilename.
    d_Ds = gb.compute_Ds(Mx, My)
    write_memory_to_file(d_Ds, Dsfilename)

    # Spatial reconstruction functions -> dirichfilename.
    d_dirich = gb.compute_dirich_space_fft(d_Ds, Mx, My, decPx, decPy,
                                           decSx, decSy, Wx, Wy)
    write_memory_to_file(d_dirich, dirichfilename)
    del d_dirich
    spfile.close()
def VTDMb(spikefile, Dsfilename, dirichfilename, start_time, end_time, dt,
          Mx, My=None, Mt=None, domain=None, Wx=None, Wy=None, Wt=None,
          lamb=0.0, dtype=np.float64, rnn=False, alpha=50000, steps=4000,
          stitching=False, stitch_interval=None, output_format=0,
          output="rec"):
    """Reconstruct video using VTDM with dirichlet kernel, assuming IAF neurons.

    VTDM_prepb must be called first (and re-called for different Mx, My
    or a new spikefile).

    Parameters
    ----------
    spikefile : string
        the file generated by VTEM containing spike info
    Dsfilename : string
        file generated by VTDM_prepb (RF dirichlet coefficients)
    dirichfilename : string
        file generated by VTDM_prepb (reconstruction functions)
    start_time : float
        the starting time of the segment to reconstruct
    end_time : float
        the ending time of the segment to reconstruct
    dt : float
        the interval between two consecutive frames in the output
    Mx : integer
        order of dirichlet space in x variable
        must be the same as used in VTDM_prepb
    My : integer, optional
        order of dirichlet space in y variable
        must be the same as used in VTDM_prepb
        if not specified, My = Mx
    Mt : integer, optional
        order of dirichlet space in t variable
        if not specified, will infer from spikefile
    domain : list, optional
        4-list or 4-tuple, [xstart, xend, ystart, yend]
        the spatial domain to recover
        must be the same as in VTDM_prepb
    Wx, Wy, Wt : float, optional
        bandwidths in x, y and t variables
        if not specified, will use the info in spikefile
    lamb : float, optional
        smoothing coefficient lambda
    dtype : optional
        np.float128 or np.float64, data type of the output
        If not specified, will be set to np.float64
    rnn : bool, optional
        solve the linear system with a recurrent neural network instead of
        the symmetric direct solver
    alpha, steps : optional
        rnn solver parameters (learning rate and iteration count)
    stitching : bool, optional
        True if the recovery algorithm should use stitching.
        Stitching should ideally be used only for long videos.
    stitch_interval : float, optional
        individual segment length when stitching; defaults to an interval
        corresponding to 20 frames.  For a larger video on a single GPU
        this should be explicitly set to a lower value.
    output_format : integer, optional
        0 to write recovered video to avi file
        1 to write recovered video to h5 file
        Anything else to not write recovered video to disc
    output : string, optional
        If output_format is 0 or 1, specifies the output filename

    Returns
    -------
    The recovered video as a numpy array.
    """
    spfile = ss.vtem_storage_read(spikefile)
    spfile.read_vrf_type()
    num_neurons = spfile.read_num_neurons()
    # single-argument print(): identical output under Python 2 and 3
    print("total neurons: %d" % (num_neurons))

    Wt1, Wx1, Wy1, Px, Py, dx, dy = spfile.read_video_attributes()

    # Fall back to the bandwidths recorded in the spike file.
    if Wx is None:
        Wx = Wx1
    else:
        Wx = float(Wx)
    if Wy is None:
        Wy = Wy1
    else:
        Wy = float(Wy)
    if Wt is None:
        Wt = Wt1
    else:
        Wt = float(Wt)

    Mx = int(Mx)
    if My is None:
        My = Mx
    else:
        My = int(My)

    rfSx = (2*np.pi) * (Mx/Wx)
    rfSy = (2*np.pi) * (My/Wy)
    rfPx = int(np.ceil(rfSx/dx/16) * 16)
    rfPy = int(np.ceil(rfSy/dy/16) * 16)

    # Clamp the recovery domain inside the original stimulus extent.
    oriSx = Px*dx
    oriSy = Py*dy
    if domain is None:
        domain = [-oriSx/2, oriSx/2, -oriSy/2, oriSy/2]
    else:
        domain = [max(float(domain[0]), -oriSx/2),
                  min(float(domain[1]), oriSx/2),
                  max(float(domain[2]), -oriSy/2),
                  min(float(domain[3]), oriSy/2)]
    spfile.select_neurons(domain)
    print("select neurons: %d" % (spfile.decode_neurons))

    decSx = domain[1] - domain[0]
    decSy = domain[3] - domain[2]
    decPx = int(np.round(decSx/dx))
    decPy = int(np.round(decSy/dy))

    if Mt is None:
        Mt = int(round(Wt/np.pi))
    else:
        Mt = int(Mt)

    end = float(end_time)
    if stitching:
        if stitch_interval is None:
            stitch_interval = 20*dt
        stitch_interval = min(stitch_interval, end_time)
        # Snap the first segment end to a frame boundary.
        end = round(float(min(start_time + stitch_interval, end_time))/dt)*dt
        overlap = 0.2*stitch_interval
    else:
        stitch_interval = float(end_time)
        overlap = 0
    start = float(start_time)

    s = read_file(dirichfilename).shape

    # FIX: round() returns float; these are used as array dimensions and
    # slice indices, which modern NumPy rejects -- cast to int explicitly.
    total_l = int(round((end_time - start_time)/dt))
    overlap_l = int(round(overlap/dt))
    # Enforce a minimum overlap of 4 frames when segments are long enough
    # to support it; otherwise disable overlapping entirely.
    if overlap_l < 3:
        if round(stitch_interval/dt) >= 8:
            overlap_l = 4
            overlap = 4*dt
        else:
            overlap = 0
            overlap_l = 0
    interval_l = int(round(stitch_interval/dt))

    # Hanning cross-fade windows used to blend consecutive segments.
    left_window = np.hanning(2*overlap_l)[0:overlap_l]
    left_window.shape = (left_window.shape[0], 1, 1)
    left_window = np.tile(left_window, (1, s[1], s[2]))
    right_window = np.ones(left_window.shape) - left_window

    u = np.empty([total_l, s[1], s[2]], dtype=dtype)

    i = 0
    while True:
        # Decode one temporal segment [start, end].
        tk1, tk2, neuron_ind, h_kappa, h_delta, h_bias, h_sigma = \
            spfile.read_select_spikes([start, end])
        # Normalize by kappa*sigma where sigma is nonzero, then rescale so
        # the largest weight is 1; restrict to the selected neurons.
        h_norm = np.ones(spfile.num_neurons, np.float64)
        nonzerosigma = np.nonzero(h_sigma != 0)
        h_norm[nonzerosigma] = 1.0/(h_kappa[nonzerosigma]*h_sigma[nonzerosigma])
        h_norm = h_norm/h_norm.max()
        h_norm = h_norm[spfile.choose]

        cp = dr.dirichlet(spfile.decode_neurons, tk1, tk2, neuron_ind,
                          h_kappa, h_delta, h_bias, h_norm, Wt, Mt, dtype)
        cp.compute_q()
        cp.compute_Gb(Dsfilename, float(lamb))
        if rnn:
            cp.d_q = nn.rnn3(cp.d_G, cp.d_q, alpha=alpha, steps=steps)
        else:
            cp.d_q = ls.solve_eq_sym(cp.d_G, cp.d_q)
        cp.freeG()
        u_rec = cp.reconstruct(dirichfilename,
                               [0, round((end-start)/dt)*dt], dt)
        u_rec = u_rec.get()

        if overlap_l:
            if i == 0:
                if end_time - end < dt:
                    # Single segment covers everything -- done.
                    u = u_rec
                    break
                else:
                    u[0:(interval_l - overlap_l), :, :] = u_rec[0:-overlap_l, :, :]
                    prev = u_rec[-overlap_l:, :, :]
            else:
                l_ind = i*(interval_l - overlap_l)
                r_ind = (i+1)*(interval_l - overlap_l)
                # Cross-fade the overlap region with the previous segment.
                u[l_ind:(l_ind + overlap_l), :, :] = \
                    u_rec[0:overlap_l, :, :]*left_window + prev*right_window
                if end_time - end < dt:
                    u[(l_ind + overlap_l):, :, :] = \
                        u_rec[(-u.shape[0]+l_ind+overlap_l):, :, :]
                    break
                u[(l_ind + overlap_l):r_ind, :, :] = \
                    u_rec[overlap_l:interval_l-overlap_l, :, :]
                prev = u_rec[-overlap_l:, :, :]
        else:
            u[(i*interval_l):((i*interval_l) + u_rec.shape[0]), :, :] = u_rec
            if end_time - end < dt:
                break

        # Advance to the next segment, backing up by the overlap.
        start = start + round(stitch_interval/dt)*dt - round(overlap/dt)*dt
        end = min(start + round(stitch_interval/dt)*dt, float(end_time))
        if (float(end_time) - end < overlap):
            end = end_time
        i = i + 1

    # FIX: the original leaked the spike-file handle (sibling VTDM_prepb
    # closes its handle after use).
    spfile.close()

    if output_format == 0:
        vio.write_video(u, output)
    elif output_format == 1:
        write_memory_to_file(u, output)
    return u
def decode_video(spikefile, Dswfilename, dirichfilename, start_time,
                 end_time, dt, Mx, My=None, Mt=None, domain=None, Wx=None,
                 Wy=None, Wt=None, lamb=0.0, dtype=np.float64, rnn=False,
                 alpha=50000, steps=4000, stitching=False,
                 stitch_interval=None, spatial_stitching=False,
                 spatial_interval=None, precompute=True, write_blocks=False,
                 output_format=0, output="rec"):
    """Reconstruct video using VTDM with dirichlet kernel, assuming IAF neurons.

    Parameters
    ----------
    spikefile : string
        the file generated by VTEM containing spike info
    Dswfilename : string
        file generated by VTDM_prep (or VTDM_prepb)
    dirichfilename : string
        file generated by VTDM_prep (or VTDM_prepb)
    start_time : float
        the starting time of the segment to reconstruct
    end_time : float
        the ending time of the segment to reconstruct
    dt : float
        the interval between two consecutive frames in the output
    Mx : integer
        order of dirichlet space in x variable
        must be the same as used in VTDM_prep
    My : integer, optional
        order of dirichlet space in y variable
        must be the same as used in VTDM_prep
        if not specified, My = Mx
    Mt : integer, optional
        order of dirichlet space in t variable
        if not specified, will infer from spikefile
    domain : list, optional
        4-list or 4-tuple, [xstart, xend, ystart, yend]
        the spatial domain to recover
        must be the same as in VTDM_prep
    Wx, Wy, Wt : float, optional
        bandwidths in x, y and t variables
        if not specified, will use the info in spikefile
    lamb : float, optional
        smoothing coefficient lambda
    dtype : optional
        np.float128 or np.float64, data type of the output
        If not specified, will be set to np.float64
    stitching : bool, optional
        True if the recovery algorithm should use temporal stitching.
        If not specified, will be set to False.
        Stitching should ideally be used only for long videos.
    stitch_interval : float, optional
        If stitching is set to True, stitch_interval will set the
        individual segment lengths.  If not specified, will default to an
        interval corresponding to 20 frames.  If a larger video is being
        decoded, this should be explicitly set to a lower value if a
        single GPU is being used.
    spatial_stitching : bool, optional
        True if the recovery algorithm should use spatial stitching.
        If the video is of high resolution, should be set to True.
        Spatial stitching will be disabled if the domain of recovery is
        less than 96x96 pixels.  If not specified, will be set to False.
    spatial_interval : list, optional
        If spatial_stitching is set to True, determines the block size for
        spatial stitching.  If not specified, defaults to [96, 96] pixels.
    precompute : bool, optional
        If set to True, inner products will be calculated and stored in an
        h5 file before the VTDM call.  Defaults to True.
    write_blocks : bool, optional
        Only used if spatial_stitching is True.  If set to True, writes
        the individual blocks to disc as well.  Defaults to False.
    output_format : integer, optional
        0 to write recovered video to avi file
        1 to write recovered video to h5 file
        Anything else to not write recovered video to disc
    output : string, optional
        output filename in which the reconstructed video will be stored

    Returns
    -------
    The recovered video as a numpy array.

    Notes
    -----
    On a single device with 3 GB memory, around 27000 spikes can be
    decoded.  Spatial and temporal stitching can be used to recover a
    video that produces more spikes than can be decoded on a single GPU.

    The coordinate system is row (width / X) major: X increases along the
    width of a frame, Y down the height.

    Examples
    --------
    >>> import atexit
    >>> import pycuda.driver as cuda
    >>> import numpy as np
    >>> from vtem import vtem, vtdm
    >>> cuda.init()
    >>> context1 = cuda.Device(0).make_context()
    >>> atexit.register(cuda.Context.pop)
    >>> vtem.VTEM_Gabor_IAF('fly_large.avi', 'spikes.h5',
    ...                     2*np.pi*10, h5input=False)
    >>> vtdm.decode_video('spikes.h5', 'dsw.h5', 'dirich.h5', 0, 1, 0.01,
    ...                   Mx, rnn=True, alpha=5000, steps=4000,
    ...                   dtype=np.float32, stitching=True,
    ...                   stitch_interval=0.2, spatial_stitching=True,
    ...                   spatial_interval=[80, 80], output_format=0,
    ...                   write_blocks=True)
    """
    # FIX: mutable default argument replaced by a None sentinel
    # (backward compatible -- same effective default).
    if spatial_interval is None:
        spatial_interval = [96, 96]

    if not spatial_stitching:
        if precompute:
            VTDM_prep(spikefile, Dswfilename, dirichfilename, Mx, My,
                      domain, Wx, Wy, dtype=np.complex128)
            # NOTE(review): VTDM's signature is not visible in this file;
            # the extra positional False is kept exactly as the original
            # call had it -- confirm it matches VTDM's parameter list.
            return VTDM(spikefile, Dswfilename, dirichfilename, start_time,
                        end_time, dt, Mx, My, Mt, domain, Wx, Wy, Wt, lamb,
                        dtype, rnn, alpha, steps, stitching, stitch_interval,
                        False, output_format, output)
        else:
            VTDM_prepb(spikefile, Dswfilename, dirichfilename, Mx, My,
                       domain, Wx, Wy, dtype=np.complex128)
            # FIX: the original passed a stray extra positional argument
            # (False) here; VTDMb takes exactly 22 parameters, so the call
            # raised TypeError.  The spatial-stitching branch below passes
            # the correct 22 arguments.
            return VTDMb(spikefile, Dswfilename, dirichfilename, start_time,
                         end_time, dt, Mx, My, Mt, domain, Wx, Wy, Wt, lamb,
                         dtype, rnn, alpha, steps, stitching,
                         stitch_interval, output_format, output)

    # --- spatial stitching: decode overlapping blocks and blend them ---
    spfile = ss.vtem_storage_read(spikefile)
    # FIX: the original unpacked the file's temporal bandwidth into Wt,
    # silently discarding a caller-supplied Wt; the file value is not
    # needed here (VTDM/VTDMb default Wt from the file themselves).
    _Wt, Wx1, Wy1, Px, Py, dx, dy = spfile.read_video_attributes()
    # FIX: close the spike-file handle once the attributes are read.
    spfile.close()

    if Wx is None:
        Wx = Wx1
    else:
        Wx = float(Wx)
    if Wy is None:
        Wy = Wy1
    else:
        Wy = float(Wy)

    Mx = int(Mx)
    if My is None:
        My = Mx
    else:
        My = int(My)

    # Clamp the recovery domain inside the original stimulus extent.
    oriSx = Px*dx
    oriSy = Py*dy
    if domain is None:
        domain = [-oriSx/2, oriSx/2, -oriSy/2, oriSy/2]
    else:
        domain = [max(float(domain[0]), -oriSx/2),
                  min(float(domain[1]), oriSx/2),
                  max(float(domain[2]), -oriSy/2),
                  min(float(domain[3]), oriSy/2)]

    decSx = domain[1] - domain[0]
    decSy = domain[3] - domain[2]
    decPx = int(np.round(decSx / dx))
    decPy = int(np.round(decSy / dy))

    # Block size cannot exceed the decoded frame; blocks overlap by 20%.
    spatial_interval = [min(spatial_interval[0], decPx),
                        min(spatial_interval[1], decPy)]
    # FIX: int() casts -- round() returns float, which cannot be used as
    # a slice index / array length under modern NumPy and Python 3.
    overlap = [int(round(.2*spatial_interval[0])),
               int(round(.2*spatial_interval[1]))]

    start_y = max(domain[3] - spatial_interval[1]*dy, domain[2])
    end_y = domain[3]

    total_l = int(round((end_time - start_time)/dt))
    u = np.zeros([total_l, decPy, decPx], dtype=dtype)
    start_y_i = 0
    end_y_i = spatial_interval[1]
    i = 0

    # Individual blocks are written out only when requested; otherwise the
    # inner VTDM/VTDMb calls are told not to touch the disc (format 2).
    if write_blocks:
        output_format_int = 1 if output_format == 1 else 0
    else:
        output_format_int = 2

    while True:
        start_x = max(domain[1] - spatial_interval[0]*dx, domain[0])
        end_x = domain[1]
        start_x_i = decPx - spatial_interval[0]
        end_x_i = decPx
        while True:
            curr_domain = [start_x, end_x, start_y, end_y]
            if precompute:
                VTDM_prep(spikefile, Dswfilename, dirichfilename, Mx, My,
                          curr_domain, Wx, Wy, dtype=np.complex128)
            else:
                VTDM_prepb(spikefile, Dswfilename, dirichfilename, Mx, My,
                           curr_domain, Wx, Wy, dtype=np.complex128)

            # 2-D Hanning blending window: fades toward neighbouring blocks
            # and is flat at the frame borders (no neighbour there).
            # (The original also pre-filled window_matrix with np.ones and
            # immediately overwrote it -- dead code, removed.)
            left_size = 0 if start_x_i == 0 else overlap[0]
            right_size = 0 if end_x_i == decPx else overlap[0]
            left_w = np.hanning(2*left_size)[:left_size]
            right_w = 1 - np.hanning(2*right_size)[:right_size]
            x_window = np.concatenate((left_w,
                np.ones(end_x_i-start_x_i-left_size-right_size), right_w))
            top_size = 0 if start_y_i == 0 else overlap[1]
            bottom_size = 0 if end_y_i == decPy else overlap[1]
            top_w = np.hanning(2*top_size)[:top_size]
            bottom_w = 1 - np.hanning(2*bottom_size)[:bottom_size]
            y_window = np.concatenate((top_w,
                np.ones(end_y_i-start_y_i-bottom_size-top_size), bottom_w))
            y_window.shape = (y_window.shape[0], 1)
            x_window.shape = (1, x_window.shape[0])
            window_matrix = y_window.dot(x_window)

            # Decode this block and blend it into the output volume.
            if precompute:
                u[:, start_y_i:end_y_i, start_x_i:end_x_i] += np.multiply(
                    VTDM(spikefile, Dswfilename, dirichfilename, start_time,
                         end_time, dt, Mx, My, Mt, curr_domain, Wx, Wy, Wt,
                         lamb, dtype, rnn, alpha, steps, stitching,
                         stitch_interval, output_format_int,
                         output+str(i)), window_matrix)
            else:
                u[:, start_y_i:end_y_i, start_x_i:end_x_i] += np.multiply(
                    VTDMb(spikefile, Dswfilename, dirichfilename, start_time,
                          end_time, dt, Mx, My, Mt, curr_domain, Wx, Wy, Wt,
                          lamb, dtype, rnn, alpha, steps, stitching,
                          stitch_interval, output_format_int,
                          output+str(i)), window_matrix)
            i += 1

            if start_x == domain[0]:
                break
            # Step leftwards by one block minus the overlap.
            start_x = max(start_x - spatial_interval[0]*dx + overlap[0]*dx,
                          domain[0])
            end_x = end_x - spatial_interval[0]*dx + overlap[0]*dx
            start_x_i = max(start_x_i - spatial_interval[0] + overlap[0], 0)
            end_x_i = end_x_i - spatial_interval[0] + overlap[0]

        if start_y == domain[2]:
            break
        # Step downwards by one block minus the overlap.
        start_y = max(start_y - spatial_interval[1]*dy + overlap[1]*dy,
                      domain[2])
        end_y = end_y - spatial_interval[1]*dy + overlap[1]*dy
        start_y_i = end_y_i - overlap[1]
        end_y_i = min(start_y_i + spatial_interval[1], decPy)

    if output_format == 0:
        vio.write_video(u, output+".avi")
    elif output_format == 1:
        write_memory_to_file(u, output+".h5")
    return u