def extract_spikes_alt(roiattrs):
    """Infer approximate spike rates for each ROI trace using c2s.

    The neuropil-corrected traces (``roiattrs['corr_traces']``) are used when
    that key exists; otherwise the raw ``roiattrs['traces']`` are used.  Each
    trace is wrapped in its own c2s data entry, preprocessed, and passed to
    ``c2s.predict`` without an explicit model.

    Args:
        roiattrs: Mapping holding at least ``'traces'`` and ``'idxs'`` (and
            optionally ``'corr_traces'``).  It is modified in place.

    Returns:
        The same ``roiattrs`` object with two entries added:
        ``'spike_inf'``, an array of shape (number of ROIs, number of frames)
        holding the predictions averaged in groups of four samples, and
        ``'spike_long'``, the squeezed array of the unaveraged predictions.
    """
    # print() with a single argument works on both Python 2 and Python 3.
    print("\nrunning spike extraction")
    import c2s

    # NOTE(review): the acquisition rate is hard-coded -- confirm that it
    # matches the recordings being processed.
    frameRate = 25
    trace_type = 'corr_traces' if 'corr_traces' in roiattrs else 'traces'
    data = [{'calcium': np.array([trace]), 'fps': frameRate}
            for trace in roiattrs[trace_type]]
    spkt = c2s.predict(c2s.preprocess(data), verbosity=0)

    nROIs = len(roiattrs['idxs'])
    # The frame count always comes from the raw traces, even when the
    # corrected traces were used for inference.
    cFrames = np.array(roiattrs['traces']).shape[1]

    spk_traces = np.zeros([nROIs, cFrames])
    spk_long = []
    for i in range(nROIs):
        predictions = spkt[i]['predictions']
        # Collapse every 4 consecutive prediction samples into one value per
        # frame; presumably preprocess resamples 25 fps to 100 fps -- TODO
        # confirm against the c2s defaults.
        spk_traces[i] = np.mean(predictions.reshape(-1, 4), axis=1)
        spk_long.append(predictions)

    roiattrs['spike_inf'] = spk_traces
    roiattrs['spike_long'] = np.squeeze(np.array(spk_long))
    return roiattrs
Esempio n. 2
0
    def spike_traces(self, X, fps):
        """Yield c2s rate predictions for each calcium trace in ``X``.

        This is a generator, so nothing below runs (including the import and
        the language check) until the result is iterated.  Entries are only
        produced when ``self.fetch1['spike_method'] == 3``; for any other
        method the generator finishes without yielding.

        Args:
            X: Sequence of dict-like entries, each holding its signal under
                the key ``'trace'``.  Entries are modified in place (the
                ``'trace'`` key is removed; ``'calcium'`` and ``'fps'`` are
                added).
            fps: Sampling rate stored on each entry and passed to
                ``c2s.preprocess``.

        Yields:
            The c2s result entry for each trace with ``'predictions'``,
            ``'calcium'`` and ``'fps'`` removed and ``'rate_trace'`` added.

        Raises:
            AssertionError: If ``self.fetch1['language']`` is not ``'python'``.
        """
        try:
            import c2s
        except ImportError:
            # Only a missing package is tolerated; anything else (including
            # KeyboardInterrupt/SystemExit) must propagate.
            warn(
                "c2s was not found. You won't be able to populate ExtracSpikes"
            )
        assert self.fetch1[
            'language'] == 'python', "This tuple cannot be computed in python."
        if self.fetch1['spike_method'] == 3:
            N = len(X)
            for i, trace in enumerate(X):
                print('Predicting trace %i/%i' % (i + 1, N))
                tr0 = np.array(trace.pop('trace').squeeze())
                # presumably notnan returns the index of the first (or, with
                # increment=-1, the last) non-NaN sample -- verify against
                # the helper's definition.
                start = notnan(tr0)
                end = notnan(tr0, len(tr0) - 1, increment=-1)
                trace['calcium'] = np.atleast_2d(tr0[start:end + 1])

                trace['fps'] = fps
                data = c2s.preprocess([trace], fps=fps)
                data = c2s.predict(data, verbosity=0)

                # Write the predictions back into the start..end window of the
                # original array; samples outside that window keep their
                # original values.
                tr0[start:end + 1] = data[0].pop('predictions')
                data[0]['rate_trace'] = tr0.T
                data[0].pop('calcium')
                data[0].pop('fps')

                yield data[0]
Esempio n. 3
0
def main(argv):
    """Command-line entry point: predict firing rates for a dataset.

    Loads the dataset, optionally preprocesses it, runs ``predict`` with
    either the models stored in the experiment file given by ``--model`` or
    the default model, strips everything but the predictions and writes the
    result to every requested output file.

    Args:
        argv: Full argument vector; ``argv[0]`` is used as the program name
            and ``argv[1:]`` is parsed.

    Returns:
        ``0`` on success.
    """
    parser = ArgumentParser(argv[0], description=__doc__)
    parser.add_argument('dataset', type=str)
    parser.add_argument('output', type=str, nargs='+')
    parser.add_argument('--model', '-m', type=str, default='')
    parser.add_argument(
        '--preprocess',
        '-p',
        type=int,
        default=0,
        help=
        'If you haven\'t already applied `preprocess` to the data, set to 1 (default: 0).'
    )
    parser.add_argument('--verbosity', '-v', type=int, default=1)

    args = parser.parse_args(argv[1:])

    # NOTE(review): this instance is never read afterwards; it is kept in
    # case constructing an Experiment has side effects -- confirm.
    experiment = Experiment()

    # load data
    data = load_data(args.dataset)

    if args.preprocess:
        # Pass verbosity by keyword: given positionally it would land in
        # preprocess's second parameter, which is assumed to be `fps`
        # (as in c2s.preprocess) -- confirm against the imported function.
        data = preprocess(data, verbosity=args.verbosity)

    if args.model:
        # load training results
        results = Experiment(args.model)['models']
    else:
        # use default model
        results = None

    # predict firing rates
    data = predict(data, results, verbosity=args.verbosity)

    # remove data except predictions
    for entry in data:
        if 'spikes' in entry:
            del entry['spikes']
        if 'spike_times' in entry:
            del entry['spike_times']
        del entry['calcium']

    for filepath in args.output:
        if filepath.lower().endswith('.mat'):
            # store in MATLAB format
            savemat(filepath, {'data': data})
        else:
            # Pickle protocol 2 is a binary format, so the file must be
            # opened in binary mode.
            with open(filepath, 'wb') as handle:
                dump(data, handle, protocol=2)

    return 0
Esempio n. 4
0
def main(argv):
	"""Command-line entry point: predict firing rates for a dataset.

	Loads the dataset, optionally preprocesses it, runs ``predict`` with
	either the models stored in the experiment file given by ``--model`` or
	the default model, strips everything but the predictions and writes the
	result to every requested output file.

	Args:
		argv: Full argument vector; ``argv[0]`` is used as the program name
			and ``argv[1:]`` is parsed.

	Returns:
		``0`` on success.
	"""
	parser = ArgumentParser(argv[0], description=__doc__)
	parser.add_argument('dataset',            type=str)
	parser.add_argument('output',             type=str, nargs='+')
	parser.add_argument('--model',      '-m', type=str, default='')
	parser.add_argument('--preprocess', '-p', type=int, default=0,
		help='If you haven\'t already applied `preprocess` to the data, set to 1 (default: 0).')
	parser.add_argument('--verbosity',  '-v', type=int, default=1)

	args = parser.parse_args(argv[1:])

	# NOTE(review): this instance is never read afterwards; it is kept in
	# case constructing an Experiment has side effects -- confirm.
	experiment = Experiment()

	# load data
	data = load_data(args.dataset)

	if args.preprocess:
		# Pass verbosity by keyword: given positionally it would land in
		# preprocess's second parameter, which is assumed to be `fps`
		# (as in c2s.preprocess) -- confirm against the imported function.
		data = preprocess(data, verbosity=args.verbosity)

	if args.model:
		# load training results
		results = Experiment(args.model)['models']
	else:
		# use default model
		results = None

	# predict firing rates
	data = predict(data, results, verbosity=args.verbosity)

	# remove data except predictions
	for entry in data:
		if 'spikes' in entry:
			del entry['spikes']
		if 'spike_times' in entry:
			del entry['spike_times']
		del entry['calcium']

	for filepath in args.output:
		if filepath.lower().endswith('.mat'):
			# store in MATLAB format
			savemat(filepath, {'data': data})
		else:
			# Pickle protocol 2 is a binary format, so the file must be
			# opened in binary mode.
			with open(filepath, 'wb') as handle:
				dump(data, handle, protocol=2)

	return 0
Esempio n. 5
0
    def infer_spikes(self, X, dt, trace_name='ca_trace'):
        """Run c2s spike inference on every entry of ``X``.

        Each entry is modified in place: the signal stored under
        ``trace_name`` is removed and its transpose stored as ``'calcium'``,
        and the sampling rate ``1 / dt`` is stored as ``'fps'``.  The entry is
        then preprocessed and predicted with c2s.

        Args:
            X: Sequence of dict-like trace entries.
            dt: Sampling interval; the frame rate is computed as ``1 / dt``.
            trace_name: Key under which each entry holds its signal.

        Returns:
            List with one c2s result entry per trace, carrying the transposed
            predictions under ``'spike_trace'`` and no ``'predictions'``,
            ``'calcium'`` or ``'fps'`` keys.

        Raises:
            AssertionError: If ``self.fetch1['language']`` is not ``'python'``.
        """
        language = self.fetch1['language']
        assert language == 'python', "This tuple cannot be computed in python."

        fps = 1 / dt
        inferred = []
        total = len(X)
        for position, entry in enumerate(X, 1):
            print('Predicting trace %i/%i' % (position, total))

            # Rename the signal to the key c2s expects.
            signal = entry.pop(trace_name)
            entry['calcium'] = signal.T
            entry['fps'] = fps

            predicted = c2s.predict(c2s.preprocess([entry], fps=fps),
                                    verbosity=0)
            outcome = predicted[0]
            outcome['spike_trace'] = outcome.pop('predictions').T
            # Drop the inputs so only the inference result remains.
            for key in ('calcium', 'fps'):
                outcome.pop(key)
            inferred.append(outcome)
        return inferred
Esempio n. 6
0
    def infer_spikes(self, X, dt, trace_name='ca_trace'):
        """Infer spike traces for all entries of ``X`` with c2s.

        For every entry the signal under ``trace_name`` is moved (transposed)
        to ``'calcium'`` and the frame rate ``1 / dt`` is stored under
        ``'fps'``; the entry is changed in place.  The entry is then run
        through ``c2s.preprocess`` and ``c2s.predict``.

        Args:
            X: Sequence of dict-like trace entries.
            dt: Sampling interval used to derive the frame rate.
            trace_name: Key holding the signal in each entry.

        Returns:
            A list of c2s result entries, one per trace, where the transposed
            predictions are stored under ``'spike_trace'`` and the
            ``'predictions'``, ``'calcium'`` and ``'fps'`` keys are removed.

        Raises:
            AssertionError: If ``self.fetch1['language']`` is not ``'python'``.
        """
        is_python = self.fetch1['language'] == 'python'
        assert is_python, "This tuple cannot be computed in python."

        fps = 1 / dt
        collected = []
        n_traces = len(X)
        for number, item in enumerate(X, start=1):
            print('Predicting trace %i/%i' % (number, n_traces))

            raw = item.pop(trace_name)
            item['calcium'] = raw.T
            item['fps'] = fps

            prepared = c2s.preprocess([item], fps=fps)
            first = c2s.predict(prepared, verbosity=0)[0]

            # Keep only the inference output on the returned entry.
            first['spike_trace'] = first.pop('predictions').T
            first.pop('calcium')
            first.pop('fps')
            collected.append(first)
        return collected
Esempio n. 7
0
def main(argv):
	"""Command-line entry point: leave-one-cell-out training and prediction.

	All datasets are loaded and concatenated.  For every cell, a model is
	trained on the entries of all other cells and used to predict the
	held-out cell's entries.  Afterwards everything except the predictions is
	removed from the entries and the result is written to ``output`` (MATLAB
	``.mat``, experiment ``.xpck`` or pickle, chosen by file extension).

	Args:
		argv: Full argument vector; ``argv[0]`` is used as the program name
			and ``argv[1:]`` is parsed (unknown arguments are ignored).

	Returns:
		``0`` on success.
	"""
	parser = ArgumentParser(argv[0], description=__doc__)
	parser.add_argument('dataset',                type=str,   nargs='+')
	parser.add_argument('output',                 type=str)
	parser.add_argument('--num_components', '-c', type=int,   default=3)
	parser.add_argument('--num_features',   '-f', type=int,   default=2)
	parser.add_argument('--num_models',     '-m', type=int,   default=4)
	parser.add_argument('--keep_all',       '-k', type=int,   default=1)
	parser.add_argument('--finetune',       '-n', type=int,   default=0)
	parser.add_argument('--num_valid',      '-s', type=int,   default=0)
	parser.add_argument('--var_explained',  '-e', type=float, default=95.)
	parser.add_argument('--window_length',  '-w', type=float, default=1000.)
	parser.add_argument('--regularize',     '-r', type=float, default=0.)
	parser.add_argument('--preprocess',     '-p', type=int,   default=0)
	parser.add_argument('--verbosity',      '-v', type=int,   default=1)

	args, _ = parser.parse_known_args(argv[1:])

	experiment = Experiment()

	# load data
	data = []
	for dataset in args.dataset:
		data = data + load_data(dataset)

	# preprocess data
	if args.preprocess:
		data = preprocess(data)

	# list of all cells
	if 'cell_num' in data[0]:
		# several trials/entries may belong to the same cell
		cells = unique([entry['cell_num'] for entry in data])
	else:
		# one cell corresponds to one trial/entry
		cells = range(len(data))
		for i in cells:
			data[i]['cell_num'] = i

	for i in cells:
		data_train = [entry for entry in data if entry['cell_num'] != i]
		data_test = [entry for entry in data if entry['cell_num'] == i]

		if args.verbosity > 0:
			# print() with a single argument works on Python 2 and 3 alike.
			print('Test cell: {0}'.format(i))

		# train on all cells but cell i
		results = train(
			data=data_train,
			num_valid=args.num_valid,
			num_models=args.num_models,
			var_explained=args.var_explained,
			window_length=args.window_length,
			keep_all=args.keep_all,
			finetune=args.finetune,
			model_parameters={
				'num_components': args.num_components,
				'num_features': args.num_features},
			training_parameters={
				'verbosity': 0},
			regularize=args.regularize,
			verbosity=1)

		if args.verbosity > 0:
			print('Predicting...')

		# predict responses of cell i
		predictions = predict(data_test, results, verbosity=0)

		# data_test holds the same entry objects as data, so this stores the
		# predictions on the entries that are saved below.
		for entry1, entry2 in zip(data_test, predictions):
			entry1['predictions'] = entry2['predictions']

	# remove data except predictions
	for entry in data:
		if 'spikes' in entry:
			del entry['spikes']
		if 'spike_times' in entry:
			del entry['spike_times']
		del entry['calcium']

	# save results
	if args.output.lower().endswith('.mat'):
		savemat(args.output, convert({'data': data}))

	elif args.output.lower().endswith('.xpck'):
		experiment['args'] = args
		experiment['data'] = data
		experiment.save(args.output)

	else:
		# Pickle protocol 2 is a binary format, so the file must be opened
		# in binary mode.
		with open(args.output, 'wb') as handle:
			dump(data, handle, protocol=2)

	return 0
Esempio n. 8
0
def main(argv):
    """Command-line entry point: leave-one-cell-out training and prediction.

    All datasets are loaded and concatenated.  For every cell, a model is
    trained on the entries of all other cells and used to predict the
    held-out cell's entries.  Afterwards everything except the predictions is
    removed from the entries and the result is written to ``output`` (MATLAB
    ``.mat``, experiment ``.xpck`` or pickle, chosen by file extension).

    Args:
        argv: Full argument vector; ``argv[0]`` is used as the program name
            and ``argv[1:]`` is parsed (unknown arguments are ignored).

    Returns:
        ``0`` on success.
    """
    parser = ArgumentParser(argv[0], description=__doc__)
    parser.add_argument('dataset', type=str, nargs='+')
    parser.add_argument('output', type=str)
    parser.add_argument('--num_components', '-c', type=int, default=3)
    parser.add_argument('--num_features', '-f', type=int, default=2)
    parser.add_argument('--num_models', '-m', type=int, default=4)
    parser.add_argument('--keep_all', '-k', type=int, default=1)
    parser.add_argument('--finetune', '-n', type=int, default=0)
    parser.add_argument('--num_valid', '-s', type=int, default=0)
    parser.add_argument('--var_explained', '-e', type=float, default=95.)
    parser.add_argument('--window_length', '-w', type=float, default=1000.)
    parser.add_argument('--regularize', '-r', type=float, default=0.)
    parser.add_argument('--preprocess', '-p', type=int, default=0)
    parser.add_argument('--verbosity', '-v', type=int, default=1)

    args, _ = parser.parse_known_args(argv[1:])

    experiment = Experiment()

    # load data
    data = []
    for dataset in args.dataset:
        data = data + load_data(dataset)

    # preprocess data
    if args.preprocess:
        data = preprocess(data)

    # list of all cells
    if 'cell_num' in data[0]:
        # several trials/entries may belong to the same cell
        cells = unique([entry['cell_num'] for entry in data])
    else:
        # one cell corresponds to one trial/entry
        cells = range(len(data))
        for i in cells:
            data[i]['cell_num'] = i

    for i in cells:
        data_train = [entry for entry in data if entry['cell_num'] != i]
        data_test = [entry for entry in data if entry['cell_num'] == i]

        if args.verbosity > 0:
            # print() with a single argument works on Python 2 and 3 alike.
            print('Test cell: {0}'.format(i))

        # train on all cells but cell i
        results = train(data=data_train,
                        num_valid=args.num_valid,
                        num_models=args.num_models,
                        var_explained=args.var_explained,
                        window_length=args.window_length,
                        keep_all=args.keep_all,
                        finetune=args.finetune,
                        model_parameters={
                            'num_components': args.num_components,
                            'num_features': args.num_features
                        },
                        training_parameters={'verbosity': 0},
                        regularize=args.regularize,
                        verbosity=1)

        if args.verbosity > 0:
            print('Predicting...')

        # predict responses of cell i
        predictions = predict(data_test, results, verbosity=0)

        # data_test holds the same entry objects as data, so this stores the
        # predictions on the entries that are saved below.
        for entry1, entry2 in zip(data_test, predictions):
            entry1['predictions'] = entry2['predictions']

    # remove data except predictions
    for entry in data:
        if 'spikes' in entry:
            del entry['spikes']
        if 'spike_times' in entry:
            del entry['spike_times']
        del entry['calcium']

    # save results
    if args.output.lower().endswith('.mat'):
        savemat(args.output, {'data': data})

    elif args.output.lower().endswith('.xpck'):
        experiment['args'] = args
        experiment['data'] = data
        experiment.save(args.output)

    else:
        # Pickle protocol 2 is a binary format, so the file must be opened
        # in binary mode.
        with open(args.output, 'wb') as handle:
            dump(data, handle, protocol=2)

    return 0
            df_F.append((corrected_trace - ffilt)/ffilt)
            raw_traces.append(trace)
            corr_traces.append(corrected_trace)

        print "running spike extraction... \n"

        #gInfo = pickle.load(open(hdf['raw_data'][session][area].attrs['GRABinfo']))
        frameRate = gInfo['scanFrameRate']
        #print np.array([corr_traces[0]]).shape
        inf = []
        for i in range(n_neurons):
            sys.stdout.write("\rrunning inference on cell: "+str(1+i)+"/"+str(n_neurons))
            sys.stdout.flush()
            data = [{'calcium':np.array([corr_traces[i]]),'fps': frameRate}]
        #data = [{'calcium':np.array([i]),'fps': frameRate} for i in corr_traces]
            inf.append(c2s.predict(c2s.preprocess(data),verbosity=0))

        print 'Saving Data'
        roiInfo['traces'] = np.array(raw_traces)
        roiInfo['corr_traces'] = np.array(corr_traces)
        roiInfo['df_F'] = np.array(df_F)
        roiInfo['spikeRate_inf'] = inf#np.array([i['predictions'] for i in inf])


        roiInfo['info'] = ['traces are raw traces',
                           'corr_traces are neuropil corrected traces',
                           'idxs are x and y coordinates of ROIs',
                           'spikeRate_inf are inferred spike rate using Theis et al 2015',
                           'df_F are neuropil corrected df/F traces',
                           'centres are the locations of the centre of the ROIs',
                           'patches are cut out patches of the mean image around the ROI',
Esempio n. 10
0
    def fit(self, dataset_paths, model_path=None, folds=5, error_margin=2):
        """Build (or load) preprocessed c2s data, then train, predict and evaluate.

        When ``model_path`` is falsy, traces, spikes and attributes are
        extracted from every path in ``dataset_paths``, packed into c2s data
        dictionaries, preprocessed in a process pool and pickled under
        ``self.cpdir``.  Otherwise the pickled data is loaded from
        ``model_path``.  The data is then passed to ``c2s.train``,
        ``c2s.predict`` and ``c2s.evaluate`` and the mean correlation is
        printed.

        NOTE(review): ``folds`` and ``error_margin`` are only referenced by
        the commented-out metrics code below.  The method also stops in the
        debugger twice (``pdb.set_trace()``), so it cannot run unattended.

        Args:
            dataset_paths: Iterable of dataset paths handed to the
                ``self.dataset_*_func`` extractors.
            model_path: Optional path of a pickle file with already
                preprocessed data; despite the name, it is loaded as data.
            folds: Unused by the active code.
            error_margin: Unused by the active code.
        """

        # NOTE(review): `logger` is created but the message below is emitted
        # through the module-level `logging.info` instead.
        logger = logging.getLogger(funcname())

        if not model_path:

            # Extract traces and spikes from datasets.
            traces = [self.dataset_traces_func(p) for p in dataset_paths]
            spikes = [self.dataset_spikes_func(p) for p in dataset_paths]
            attrs = [self.dataset_attrs_func(p) for p in dataset_paths]
            assert len(traces) == len(spikes) == len(attrs)

            # Populate C2S data dictionaries.
            # One entry per (trace, spikes) pair; each array gets a new
            # leading axis and the fps comes from the dataset's attributes.
            data = []
            for i in range(len(attrs)):
                for t, s in zip(traces[i], spikes[i]):
                    data.append({'calcium': t[np.newaxis],
                                 'spikes': s[np.newaxis],
                                 'fps': attrs[i]['sample_rate']})

            # Preprocess in parallel. This is a slow process. Using lower
            # fps creates smaller vectors. Large vectors can crash the training.
            # Leave two CPUs free, but always use at least one worker.
            pool = Pool(max(1, cpu_count() - 2))
            args = [{'data': [d], 'fps': 10, 'verbosity':0} for d in data]
            # presumably c2s_preprocess_parallel forwards each dict to
            # c2s.preprocess -- verify against the helper's definition.
            data = pool.map(c2s_preprocess_parallel, args)

            # Serialize data.
            # The file name is prefixed with the current Unix timestamp.
            data_path = '%s/%d_data.pkl' % (self.cpdir, int(time()))
            fp = open(data_path, 'wb')
            pkl.dump(data, fp)
            fp.close()
            # NOTE(review): the message says "model" but the path logged is
            # that of the serialized data.
            logging.info('Serialized model to %s' % data_path)

        else:
            # Load previously serialized data instead of rebuilding it.
            fp = open(model_path, 'rb')
            data = pkl.load(fp)
            fp.close()

        # Debugger breakpoint: execution pauses here until resumed.
        import pdb
        pdb.set_trace()

        # Train.
        results = c2s.train(data)

        # Predict.
        # NOTE(review): `data_trn` is only referenced by the commented-out
        # metrics code; the evaluation below is called on `data`.
        data_trn = c2s.predict(data, results)

        # Evaluate using C2S metrics.
        downsample_factor = 10  # fps = 100 -> fps = 10.
        # NaN correlations are replaced in place before averaging.
        corr = np.nan_to_num(c2s.evaluate(
            data, 'corr', downsampling=downsample_factor), copy=False)
        print('Corr = %.5lf' % np.mean(corr))

        # # Compute metrics.
        # p, r = 0., 0.
        # for i, d in enumerate(data_trn):
        #     yt = d['spikes'][0, np.newaxis]
        #     yp = np.clip(d['predictions'][0, np.newaxis].round(), 0, 1)
        #
        #     p += np2k(prec_margin, yt, yp, margin=error_margin)
        #     r += np2k(reca_margin, yt, yp, margin=error_margin)
        #
        #     if i % 50 == 0 or i == len(data_trn) - 1:
        #         print '%03d: mean p=%-10.3lf mean r=%-10.3lf' % (i, (p / i), (r / i))
        #
        # p, r = 0., 0.
        # for i, d in enumerate(data_val):
        #     yt = d['spikes'][0, np.newaxis]
        #     yp = np.clip(d['predictions'][0, np.newaxis].round(), 0, 1)
        #
        #     p += np2k(prec_margin, yt, yp, margin=error_margin)
        #     r += np2k(reca_margin, yt, yp, margin=error_margin)
        #
        #     if i % 50 == 0 or i == len(data_val) - 1:
        #         print '%03d: mean p=%-10.3lf mean r=%-10.3lf' % (i, (p / i), (r / i))

        # Second debugger breakpoint, after evaluation.
        import pdb
        pdb.set_trace()