Example no. 1
    def get_shifted_templates(self, temp_ids, shifts, scales):
        
        # inputs arrive as torch tensors (they carry .cpu()); cast and move
        # to the GPU directly instead of round-tripping through numpy
        temp_ids = temp_ids.long().cuda()
        shifts = shifts.float().cuda()
        scales = scales.float().cuda()

        n_sample_run = 1000
        n_times = self.templates_aligned.shape[1]

        idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts)))

        shifted_templates = torch.cuda.FloatTensor(len(shifts), n_times, self.n_neigh_chans).fill_(0)
        for j in range(len(idx_run)-1):
            ii_start = idx_run[j]
            ii_end = idx_run[j+1]
            # flat per-channel buffer, padded by 10 samples; writes start at
            # offset 5 and the pad is trimmed off below
            obj = torch.zeros(self.n_neigh_chans, (ii_end-ii_start)*n_times + 10).cuda()
            times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() + 5
            deconv.subtract_splines(obj,
                                    times,
                                    shifts[ii_start:ii_end],
                                    temp_ids[ii_start:ii_end],
                                    self.coeffs, 
                                    scales[ii_start:ii_end])
            obj = obj[:, 5:-5].reshape((self.n_neigh_chans, (ii_end-ii_start), n_times))
            shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2)
    
        return shifted_templates
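
The buffer geometry above is easy to miss: the flat per-channel buffer is padded by 10 samples, spline writes start at offset 5, and the pad is trimmed before reshaping. A minimal CPU-only sketch of that geometry, with a dummy write standing in for the CUDA kernel deconv.subtract_splines (all sizes are assumed toy values):

    import torch

    n_chans, n_spikes, n_times = 4, 3, 61
    obj = torch.zeros(n_chans, n_spikes * n_times + 10)
    times = torch.arange(0, n_spikes * n_times, n_times).long() + 5

    # stand-in for the spline kernel: write one dummy waveform per spike
    for k, t in enumerate(times):
        obj[:, t:t + n_times] = float(k + 1)

    # trim the 5-sample pad on each side, split per spike, reorder axes
    out = obj[:, 5:-5].reshape(n_chans, n_spikes, n_times)
    shifted = out.transpose(0, 1).transpose(1, 2)  # (n_spikes, n_times, n_chans)
    print(shifted.shape)                           # torch.Size([3, 61, 4])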
Example no. 2
    def add_cpp_allspikes(self):
        #start = dt.datetime.now().timestamp()

        torch.cuda.synchronize()

        # select all spikes from a previous iteration
        spike_times, spike_temps, spike_shifts, spike_heights = self.sample_spikes_allspikes(
        )

        torch.cuda.synchronize()

        # also fill in self-convolution traces with low energy so the
        #   spikes cannot be detected again (i.e. enforcing refractoriness)
        # Cat: TODO: investigate whether putting the refractoriness back in is viable
        if self.refractoriness:
            deconv.refrac_fill(
                energy=self.obj_gpu,
                spike_times=spike_times,
                spike_ids=spike_temps,
                fill_length=self.refractory * 2 + 1,  # variable fill length here
                fill_offset=self.subtraction_offset - 2 - self.refractory,
                fill_value=self.fill_value)

        torch.cuda.synchronize()

        # Add spikes back in;
        deconv.subtract_splines(self.obj_gpu, spike_times, spike_shifts,
                                spike_temps, self.coefficients,
                                -self.tempScaling * spike_heights)

        torch.cuda.synchronize()

        return
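
With hypothetical values for refractory and subtraction_offset (neither appears in the snippet), the refractory fill window works out as follows:

    # assumed values; the real ones come from the sorter's configuration
    refractory, subtraction_offset = 30, 42
    fill_length = refractory * 2 + 1                    # 61 samples
    fill_offset = subtraction_offset - 2 - refractory   # 10
    # each spike at time t is masked over samples [t+10, t+70]
    print(fill_offset, fill_offset + fill_length)       # 10 71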
Example no. 3
    def subtract_input_data(self, data, spike_train,
                            time_offsets, scales, output_numpy_format=False):

        # transfer raw data to cuda
        objective = np.copy(data)
        objective = torch.from_numpy(objective).cuda()

        time_indices = spike_train[:,0] - self.waveform_len//2
        time_indices = torch.from_numpy(time_indices).long().cuda()

        # select template ids
        template_ids = spike_train[:,1]
        template_ids = torch.from_numpy(template_ids).long().cuda()

        # select superres alignment shifts
        time_offsets = torch.from_numpy(time_offsets).float().cuda()

        # select per-spike amplitude scales
        scales = torch.from_numpy(scales).float().cuda()

        chunk_size = 10000
        for chunk in range(0, time_indices.shape[0], chunk_size):

            torch.cuda.synchronize()
            # range() never yields an empty slice here and a length-1 slice
            # stays 1-D, so a single call covers every chunk
            deconv.subtract_splines(
                                objective,
                                time_indices[chunk:chunk+chunk_size],
                                time_offsets[chunk:chunk+chunk_size],
                                template_ids[chunk:chunk+chunk_size],
                                self.coefficients,
                                self.tempScaling*scales[chunk:chunk+chunk_size])

        if output_numpy_format:
            objective = objective.cpu().data.numpy()

        # drop references so the GPU tensors can be freed
        time_indices = None
        template_ids = None
        time_offsets = None
        scales = None

        return objective
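
The chunked loop above was simplified from a version that special-cased supposedly empty slices; a quick check shows why that branch can never fire, and why a length-1 slice needs no extra wrapping:

    import torch

    t = torch.arange(25001)
    for start in range(0, t.shape[0], 10000):
        print(t[start:start + 10000].shape)  # (10000,), (10000,), (5001,): never empty
    print(t[25000:25001].shape)              # torch.Size([1]): still 1-D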
Example no. 4
    def add_cpp_allspikes(self, idx_iter):
        #start = dt.datetime.now().timestamp()

        torch.cuda.synchronize()

        # select randomly 10% of spikes from previous deconv;
        #spike_times, spike_temps, spike_shifts, flag = self.sample_spikes(idx_iter)

        # select all spikes from a previous iteration
        spike_times, spike_temps, spike_shifts, flag = self.sample_spikes_allspikes(
            idx_iter)

        torch.cuda.synchronize()

        if not flag:
            return

        # also fill in self-convolution traces with low energy so the
        #   spikes cannot be detected again (i.e. enforcing refractoriness)
        # Cat: TODO: investigate whether putting the refractoriness back in is viable
        if self.refractoriness:
            deconv.refrac_fill(
                energy=self.obj_gpu,
                spike_times=spike_times,
                spike_ids=spike_temps,
                #fill_length=self.n_time,  # variable fill length here
                #fill_offset=self.n_time//2,  # flexibility as to where the fill starts/ends (combined with the preceding arg)
                fill_length=self.refractory * 2 + 1,  # variable fill length here
                fill_offset=self.n_time // 2 + self.refractory // 2,  # flexibility as to where the fill starts/ends (combined with the preceding arg)
                fill_value=self.fill_value)

            # deconv.subtract_spikes(data=self.obj_gpu,
            # spike_times=spike_times,
            # spike_temps=spike_temps,
            # templates=self.templates_cpp_refractory_add,
            # do_refrac_fill = False,
            # refrac_fill_val = -1e10)

        torch.cuda.synchronize()

        # Add spikes back in;
        deconv.subtract_splines(self.obj_gpu, spike_times, spike_shifts,
                                spike_temps, self.coefficients,
                                -self.tempScaling)

        torch.cuda.synchronize()

        return
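
For concreteness, with assumed values n_time = 61 and refractory = 30, the fill placement used here lands later than the subtraction_offset variant of Example no. 2:

    n_time, refractory = 30 + 31, 30               # assumed values
    fill_length = refractory * 2 + 1               # 61
    fill_offset = n_time // 2 + refractory // 2    # 30 + 15 = 45
    # each spike at time t is masked over samples [t+45, t+105]
    print(fill_offset, fill_offset + fill_length)  # 45 106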
Example no. 5
    def subtract_cpp(self):

        start = dt.datetime.now().timestamp()

        torch.cuda.synchronize()

        spike_times = self.spike_times.squeeze() - self.lockout_window
        spike_temps = self.neuron_ids.squeeze()

        # zero out shifts if superres shift turned off
        # Cat: TODO: remove this computation altogether if not required;
        #           will save some time.
        if not self.superres_shift:
            self.xshifts = self.xshifts * 0

        # if single spike, wrap it in list
        # Cat: TODO make this faster/pythonic
        if self.spike_times.size()[0] == 1:
            spike_times = spike_times[None]
            spike_temps = spike_temps[None]

        deconv.subtract_splines(self.obj_gpu, spike_times, self.xshifts,
                                spike_temps, self.coefficients,
                                self.tempScaling)

        torch.cuda.synchronize()

        # also fill in self-convolution traces with low energy so the
        #   spikes cannot be detected again (i.e. enforcing refractoriness)
        # Cat: TODO: read from CONFIG

        if self.refractoriness:
            #print ("filling in timesteps: ", self.n_time)
            deconv.refrac_fill(
                energy=self.obj_gpu,
                spike_times=spike_times,
                spike_ids=spike_temps,
                #fill_length=self.n_time,  # variable fill length here
                #fill_offset=self.n_time//2,  # flexibility as to where the fill starts/ends (combined with the preceding arg)
                fill_length=self.refractory * 2 + 1,  # variable fill length here
                fill_offset=self.n_time // 2 + self.refractory // 2,  # flexibility as to where the fill starts/ends (combined with the preceding arg)
                fill_value=-self.fill_value)

        torch.cuda.synchronize()

        return (dt.datetime.now().timestamp() - start)
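
The single-spike wrap exists because .squeeze() collapses a one-element tensor all the way to 0-D; a tiny sketch:

    import torch

    spike_times = torch.tensor([[123]]).squeeze()  # squeeze -> 0-D scalar
    print(spike_times.shape)                       # torch.Size([])
    print(spike_times[None].shape)                 # torch.Size([1])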
Example no. 6
    def get_shifted_templates(self, template_ids, time_offsets, scales, input_in_torch=True):

        if not input_in_torch:
            template_ids = torch.from_numpy(template_ids).long().cuda()
            time_offsets = torch.from_numpy(time_offsets).float().cuda()
            scales = torch.from_numpy(scales).float().cuda()

        # flat per-channel buffer, padded by 10 samples; writes start at
        # offset 5 and the pad is trimmed off below
        obj = torch.cuda.FloatTensor(self.n_chan, len(template_ids)*self.waveform_len+10).fill_(0)
        times = torch.arange(0, len(template_ids)*self.waveform_len, self.waveform_len).long().cuda() + 5
        deconv.subtract_splines(obj,
                                times,
                                time_offsets,
                                template_ids,
                                self.coefficients, 
                                self.tempScaling*scales)
        # subtracting from a zero buffer leaves -template, so negate to
        # recover the shifted, scaled templates
        obj = -obj[:,5:-5].reshape((self.n_chan, len(template_ids), self.waveform_len))
        
        del template_ids
        del time_offsets
        del times
        del scales
        torch.cuda.empty_cache()

        return obj.transpose(0,1).transpose(1,2)


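The leading minus sign works because subtracting a template from a zero-filled buffer leaves its negation; a self-contained sketch of that identity (scale 2.0 is an arbitrary stand-in):

    import torch

    template = torch.randn(61)
    buf = torch.zeros(61)
    buf -= 2.0 * template    # what subtract_splines does to the buffer
    recovered = -buf         # negate to get the scaled template back
    print(torch.allclose(recovered, 2.0 * template))  # True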
Example no. 7
    def subtract_step(self):
        
        # loop over chunks and do work
        t0 = time.time()
        verbose = False
        debug = False
        
        residual_array = []
        self.reader.buffer = 200

        # open residual file for appending on the fly
        f = open(self.fname_residual,'wb')

        #self.chunk_id =0
        batch_ctr = 0
        batch_id = 0
        print (" STARTING RESIDUAL COMPUTATION...")
        for chunk in tqdm(self.reader.idx_list):
            
            time_sec = (batch_id*self.CONFIG.resources.n_sec_chunk_gpu_deconv)
                            
            #print ("time_sec: ", time_sec)
            
            # updated templates options
            if ((self.update_templates) and 
                (((time_sec)%self.template_update_time)==0) and
                (batch_id!=0)):
                
                #print ("UPDATING TEMPLATES, time_sec: ", time_sec)
                
                self.chunk_id +=1

                # Cat: TODO: note this function reads the average templates +60sec to
                #       correctly match what was computed during current window
                #       May wish to try other options
                self.fname_templates = os.path.join(os.path.split(self.data_dir)[0],
                        'deconv','template_updates',
                        'templates_'+str(time_sec+self.template_update_time)+'sec.npy')
                print ("updating templates from:  ", self.fname_templates)
                #print (" updating bsplines...")
                if False:
                    self.load_templates()
                    self.make_bsplines()
                else:
                    self.load_vis_units()
                    self.load_temp_temp()
                    self.initialize_cpp()
                    self.templates_to_bsplines()                    
           
            # load chunk starts and ends for indexing below
            chunk_start = chunk[0]
            chunk_end = chunk[1]
        
            # read and pad data with buffer
            # self.data_temp = self.reader.read_data_batch(batch_id, add_buffer=True)
            data_chunk = self.reader.read_data_batch(batch_id, add_buffer=True).T

            # transfer raw data to cuda
            objective = torch.from_numpy(data_chunk).cuda()
            if verbose: 
                print ("Input size: ",objective.shape, int(sys.getsizeof(objective)), "MB")

            # Cat: TODO: may wish to pre-compute spike indexes in chunks using cpu-multiprocessing
            #            because this constant search is expensive;
            # index into spike train at chunks:
            # Cat: TODO: this may miss spikes that land exactly at time 61.
            idx = np.where(np.logical_and(self.spike_train[:,0]>=(chunk_start-self.waveform_len), 
                            self.spike_train[:,0]<=(chunk_end+self.waveform_len)))[0]
            if verbose: 
                print (" # idx of spikes in chunk ", idx.shape, idx)
            
            # offset time indices by added buffer above
            times_local = (self.spike_train[idx,0]+self.reader.buffer-chunk_start
                                                  -self.waveform_len//2)
            time_indices = torch.from_numpy(times_local).long().cuda()
            # spike_list.append(times_local+chunk_start)
            if verbose: 
                print ("spike times/time_indices: ", time_indices.shape, time_indices)

            # select template ids
            templates_local = self.spike_train[idx,1]
            template_ids = torch.from_numpy(templates_local).long().cuda()
            if verbose: 
                print (" template ids: ", template_ids.shape, template_ids)

            # select superres alignment shifts
            time_offsets_local = self.time_offsets[idx]
            time_offsets_local = torch.from_numpy(time_offsets_local).float().cuda()

            if verbose: 
                print ("time offsets: ", time_offsets_local.shape, time_offsets_local)
            
            if verbose:
                t5 = time.time()
                
            # number of spikes to be subtracted per iteration
            # Cat: TODO: read this from CONFIG;
            if True:   # chunked subtraction; flip to False for the unit-wise path below
                chunk_size = 10000
                for chunk in range(0, time_indices.shape[0], chunk_size):
                    torch.cuda.synchronize()
                    # every slice from range() holds at least one spike and
                    # stays 1-D, so a single call covers all cases
                    deconv.subtract_splines(
                                        objective,
                                        time_indices[chunk:chunk+chunk_size],
                                        time_offsets_local[chunk:chunk+chunk_size],
                                        template_ids[chunk:chunk+chunk_size],
                                        self.coefficients,
                                        self.tempScaling)

            # do unit-wise subtraction; thread safe
            else:
                for unit in np.unique(template_ids.cpu().data.numpy()):
                    #print ('unit: ', unit)
                    torch.cuda.synchronize()
                    
                    idx_unit = np.where(template_ids.cpu().data.numpy()==unit)[0]
                    if idx_unit.shape[0]==1:
                        deconv.subtract_splines(
                                            objective,
                                            time_indices[idx_unit][None],
                                            time_offsets_local[idx_unit],
                                            template_ids[idx_unit][None],
                                            self.coefficients,
                                            self.tempScaling)
                    elif idx_unit.shape[0]>1:
                        deconv.subtract_splines(
                                            objective,
                                            time_indices[idx_unit],
                                            time_offsets_local[idx_unit],
                                            template_ids[idx_unit],
                                            self.coefficients,
                                            self.tempScaling)
                                            
            torch.cuda.synchronize()

            if verbose:
                print ("subtraction time: ", time.time()-t5)

            temp_out = objective[:,self.reader.buffer:-self.reader.buffer].cpu().data.numpy().copy(order='F')
            f.write(temp_out.T)
            
            batch_id+=1
            #if batch_id > 3:
            #    break
        f.close()

        print ("Total residual time: ", time.time()-t0)
            
Example no. 8
    def subtract_step_single_chunk(self):
        
        # loop over chunks and do work
        t0 = time.time()
        verbose = False
        debug = False
        
        # read and pad data with buffer
        # transfer raw data to cuda
        objective = self.data #.transpose(1,0)

        # offset time indices by added buffer above
        time_indices = self.spike_train[:,0]+1000

        # select template ids
        template_ids = self.spike_train[:,1] 

        # select superres alignment shifts
        time_offsets_local = self.time_offsets

        # dummy per-spike scales (all ones)
        tempScaling_array = time_offsets_local*0.0+1.0

        print ("objective: ", objective.shape)
        print ("time_indices: ", time_indices.shape)
        print ("template_ids: ", template_ids.shape)
        print ("time_offsets_local: ", time_offsets_local.shape)
        print ("tempScaling_array: ", tempScaling_array.shape)

        # number of spikes to be subtracted per iteration
        # Cat: TODO: read this from CONFIG;
        chunk_size = 10000
        for chunk in range(0, time_indices.shape[0], chunk_size):
            torch.cuda.synchronize()
            # every slice from range() holds at least one spike and stays
            # 1-D, so a single call covers all cases
            deconv.subtract_splines(
                                objective,
                                time_indices[chunk:chunk+chunk_size],
                                time_offsets_local[chunk:chunk+chunk_size],
                                template_ids[chunk:chunk+chunk_size],
                                self.coefficients,
                                #self.tempScaling
                                tempScaling_array
                                )
                                        
            torch.cuda.synchronize()

        return objective
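
The "dummy values" line builds a per-spike scale array of ones; with torch tensors the idiomatic equivalent would be torch.ones_like:

    import torch

    time_offsets_local = torch.zeros(5)  # assumed tensor input
    tempScaling_array = time_offsets_local * 0.0 + 1.0
    assert torch.equal(tempScaling_array, torch.ones_like(time_offsets_local))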
Example no. 9
    def subtract_step(self):

        # loop over chunks and do work
        t0 = time.time()
        verbose = False
        debug = False

        #
        if self.update_templates:
            # at which chunk templates need to be updated
            n_chunks_update = int(self.template_update_time/self.reader.n_sec_chunk)
            update_chunk = np.arange(0, self.reader.n_batches, n_chunks_update)

        # open residual file for appending on the fly
        f = open(self.fname_residual, 'wb')
        for batch_id, chunk in tqdm(enumerate(self.reader.idx_list)):

            # updated templates options
            if self.update_templates and np.any(update_chunk == batch_id):

                time_sec_start = batch_id*self.reader.n_sec_chunk
                fname_templates = os.path.join(self.templates_dir,
                                               'templates_{}sec.npy'.format(
                                                   time_sec_start))
                self.load_templates(fname_templates)
                self.make_bsplines_parallel()

            # load chunk starts and ends for indexing below
            chunk_start = chunk[0]
            chunk_end = chunk[1]

            # read and pad data with buffer
            # self.data_temp = self.reader.read_data_batch(batch_id, add_buffer=True)
            data_chunk = self.reader.read_data_batch(batch_id, add_buffer=True).T

            # transfer raw data to cuda
            objective = torch.from_numpy(data_chunk).cuda()
            if verbose: 
                print ("Input size: ",objective.shape, int(sys.getsizeof(objective)), "MB")

            # Cat: TODO: may wish to pre-compute spike indexes in chunks using cpu-multiprocessing
            #            because this constant search is expensive;
            # index into spike train at chunks:
            # Cat: TODO: this may miss spikes that land exactly at time 61.
            idx = np.where(np.logical_and(
                self.spike_train[:,0]>=(chunk_start-self.waveform_len),
                self.spike_train[:,0]<=(chunk_end+self.waveform_len)))[0]

            if verbose: 
                print (" # idx of spikes in chunk ", idx.shape, idx)
            
            # offset time indices by added buffer above
            times_local = (self.spike_train[idx,0]+self.reader.buffer-chunk_start
                                                  -self.waveform_len//2)
            time_indices = torch.from_numpy(times_local).long().cuda()
            # spike_list.append(times_local+chunk_start)
            if verbose: 
                print ("spike times: ", time_indices.shape, time_indices)

            # select template ids
            templates_local = self.spike_train[idx,1]
            template_ids = torch.from_numpy(templates_local).long().cuda()
            # id_list.append(templates_local)
            if verbose: 
                print (" template ids: ", template_ids.shape, template_ids)

            # select superres alignment shifts
            time_offsets_local = self.time_offsets[idx]
            time_offsets_local = torch.from_numpy(time_offsets_local).float().cuda()
            
            # select per-spike amplitude scales
            scales_local = self.scales[idx]
            scales_local = torch.from_numpy(scales_local).float().cuda()

            if verbose: 
                print ("time offsets: ", time_offsets_local.shape, time_offsets_local)

            if verbose:
                t5 = time.time()

            if False:
                np.save('/home/cat/times.npy', time_indices.cpu().data.numpy())
                np.save('/home/cat/objective.npy', objective.cpu().data.numpy())
                np.save('/home/cat/template_ids.npy', template_ids.cpu().data.numpy())
                np.save('/home/cat/time_offsets_local.npy', time_offsets_local.cpu().data.numpy())
                
            # number of spikes to be subtracted per iteration
            # Cat: TODO: read this from CONFIG;
            chunk_size = 10000
            for chunk in range(0, time_indices.shape[0], chunk_size):

                torch.cuda.synchronize()
                # every slice from range() holds at least one spike and
                # stays 1-D, so a single call covers all cases
                deconv.subtract_splines(
                                    objective,
                                    time_indices[chunk:chunk+chunk_size],
                                    time_offsets_local[chunk:chunk+chunk_size],
                                    template_ids[chunk:chunk+chunk_size],
                                    self.coefficients,
                                    self.tempScaling*scales_local[chunk:chunk+chunk_size])

            torch.cuda.synchronize()

            if verbose:
                print ("subtraction time: ", time.time()-t5)

            temp_out = objective[:,self.reader.buffer:-self.reader.buffer].cpu().data.numpy().copy(order='F')
            f.write(temp_out.T)
            
            #if batch_id > 3:
            #    break
        f.close()

        print ("Total residual time: ", time.time()-t0)
Example no. 10
    def subtract_cpp(self):

        start = dt.datetime.now().timestamp()

        torch.cuda.synchronize()

        if False:
            self.spike_times = self.spike_times[:1]
            self.neuron_ids = self.neuron_ids[:1]
            self.xshifts = self.xshifts[:1]
            self.heights = self.heights[:1]
            self.obj_gpu *= 0.

        #spike_times = self.spike_times.squeeze()-self.lockout_window
        spike_times = self.spike_times.squeeze() - self.subtraction_offset
        spike_temps = self.neuron_ids.squeeze()

        # zero out shifts if superres shift turned off
        # Cat: TODO: remove this computation altogether if not required;
        #           will save some time.
        if not self.superres_shift:
            self.xshifts = self.xshifts * 0

        # if single spike, wrap it in list
        # Cat: TODO make this faster/pythonic

        if self.spike_times.size()[0] == 1:
            spike_times = spike_times[None]
            spike_temps = spike_temps[None]

        #print ("spke_times: ", spike_times, spike_times)
        #print ("spke_times: ", spike_times[:20], spike_times[-20:])

        # save metadata
        if False:
            if self.n_iter < 500:
                self.objectives_dir = os.path.join(self.out_dir, 'objectives')
                if not os.path.isdir(self.objectives_dir):
                    os.mkdir(self.objectives_dir)

                np.save(
                    self.out_dir + '/objectives/spike_times_inside_' +
                    str(self.chunk_id) + "_iter_" + str(self.n_iter) + '.npy',
                    spike_times.squeeze().cpu().data.numpy())
                np.save(
                    self.out_dir + '/objectives/spike_ids_inside_' +
                    str(self.chunk_id) + "_iter_" + str(self.n_iter) + '.npy',
                    spike_temps.squeeze().cpu().data.numpy())
                np.save(
                    self.out_dir + '/objectives/obj_gpu_' +
                    str(self.chunk_id) + "_iter_" + str(self.n_iter) + '.npy',
                    self.obj_gpu.cpu().data.numpy())
                np.save(
                    self.out_dir + '/objectives/shifts_' + str(self.chunk_id) +
                    "_iter_" + str(self.n_iter) + '.npy',
                    self.xshifts.cpu().data.numpy())
                np.save(
                    self.out_dir + '/objectives/tempScaling_' +
                    str(self.chunk_id) + "_iter_" + str(self.n_iter) + '.npy',
                    self.tempScaling)
                np.save(
                    self.out_dir + '/objectives/heights_' +
                    str(self.chunk_id) + "_iter_" + str(self.n_iter) + '.npy',
                    self.heights.cpu().data.numpy())

                if False:
                    for k in range(len(self.coefficients)):
                        np.save(
                            self.out_dir + '/objectives/coefficients_' +
                            str(k) + "_" + str(self.chunk_id) + "_iter_" +
                            str(self.n_iter) + '.npy',
                            self.coefficients[k].data.cpu().numpy())
                    print("spike_times: ", spike_times.shape)
                    print("spike_times: ", type(spike_times.data[0].item()))
                    print("spike_temps: ", spike_temps.shape)
                    print("spike_temps: ", type(spike_temps.data[0].item()))
                    print("self.obj_gpu: ", self.obj_gpu.shape)
                    print("self.obj_gpu: ",
                          type(self.obj_gpu.data[0][0].item()))
                    print("self.xshifts: ", self.xshifts.shape)
                    print("self.xshifts: ", type(self.xshifts.data[0].item()))
                    print("self.tempScaling: ", self.tempScaling)
                    print("self.heights: ", self.heights.shape)
                    print("self.heights: ", type(self.heights.data[0].item()))
                    print("self.coefficients[k]: ",
                          self.coefficients[k].data.shape)
                    print("self.coefficients[k]: ",
                          type(self.coefficients[k].data[0][0].item()))
            else:
                quit()

        #self.obj_gpu = self.obj_gpu*0.
        #spike_times = spike_times -99
        deconv.subtract_splines(self.obj_gpu, spike_times, self.xshifts,
                                spike_temps, self.coefficients,
                                self.tempScaling * self.heights)

        torch.cuda.synchronize()

        # also fill in self-convolution traces with low energy so the
        #   spikes cannot be detected again (i.e. enforcing refractoriness)
        # Cat: TODO: read from CONFIG

        if self.refractoriness:
            #print ("filling in timesteps: ", self.n_time)
            deconv.refrac_fill(
                energy=self.obj_gpu,
                spike_times=spike_times,
                spike_ids=spike_temps,
                fill_length=self.refractory * 2 + 1,  # variable fill length here
                fill_offset=self.subtraction_offset - 2 - self.refractory,
                fill_value=-self.fill_value)

        torch.cuda.synchronize()

        return (dt.datetime.now().timestamp() - start)
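
Across these examples the same kernel both subtracts spikes and restores them, purely through the sign of the scale argument; a toy identity showing the add-back used in add_cpp_allspikes:

    import torch

    obj = torch.randn(4, 100)
    contrib = torch.randn(4, 100)  # stand-in for the spline contribution
    snapshot = obj.clone()
    obj -= 1.5 * contrib           # subtract:   tempScaling * heights
    obj -= -1.5 * contrib          # add back:  -tempScaling * heights
    print(torch.allclose(obj, snapshot))  # True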