예제 #1
0
def main(argv):
    """Load a dataset, preprocess it and store the result in one or more files.

    Parameters
    ----------
    argv : list of str
        Full argument vector. ``argv[0]`` is used as the program name, the
        remaining items are parsed as command-line arguments.

    Returns
    -------
    int
        Always 0.
    """
    parser = ArgumentParser(argv[0], description=__doc__)
    parser.add_argument('input', type=str)
    parser.add_argument('output', type=str, nargs='+')
    parser.add_argument('--filter', '-s', type=int, default=0)
    parser.add_argument(
        '--fps',
        '-f',
        type=float,
        default=100.,
        help='Up- or downsample data to match this sampling rate (100 fps).')
    parser.add_argument('--verbosity', '-v', type=int, default=1)

    args = parser.parse_args(argv[1:])

    # load data
    data = load_data(args.input)

    # preprocess data; a non-positive --filter value disables filtering
    data = preprocess(data,
                      fps=args.fps,
                      filter=args.filter if args.filter > 0 else None,
                      verbosity=args.verbosity)

    for filepath in args.output:
        if filepath.lower().endswith('.mat'):
            # store in MATLAB format
            savemat(filepath, convert({'data': data}))
        else:
            # protocol 2 is a binary pickle format, so the file has to be
            # opened in binary mode (text mode fails on Python 3 and can
            # corrupt the data on Windows)
            with open(filepath, 'wb') as handle:
                dump(data, handle, protocol=2)

    return 0
def extract_spikes_alt(roiattrs):
    """Infer approximate spike rates for the traces stored in ``roiattrs``.

    The neuropil-corrected traces (``roiattrs['corr_traces']``) are used when
    present, otherwise ``roiattrs['traces']``. Two items are added to
    ``roiattrs`` in place:

    * ``'spike_inf'``  -- predictions averaged over groups of 4 samples,
      one row per ROI
    * ``'spike_long'`` -- the squeezed array of the unaveraged predictions

    Returns the same ``roiattrs`` object.
    """
    # parenthesised form prints the same text on Python 2 and Python 3
    print("\nrunning spike extraction")
    import c2s

    # NOTE(review): acquisition rate is hard-coded -- confirm it matches the
    # recordings this is used on
    frameRate = 25
    trace_type = 'corr_traces' if 'corr_traces' in roiattrs else 'traces'

    # one c2s entry per trace, each holding a single-row calcium array
    data = [{'calcium': np.array([i]), 'fps': frameRate} for i in roiattrs[trace_type]]
    spkt = c2s.predict(c2s.preprocess(data), verbosity=0)

    nROIs = len(roiattrs['idxs'])
    cFrames = np.array(roiattrs['traces']).shape[1]

    spk_traces = np.zeros([nROIs, cFrames])
    spk_long = []
    for i in range(nROIs):
        predictions = spkt[i]['predictions']
        # presumably 4 predicted samples per original frame (25 fps input
        # resampled to 100 fps by c2s.preprocess) -- TODO confirm
        spk_traces[i] = np.mean(predictions.reshape(-1, 4), axis=1)
        spk_long.append(predictions)

    roiattrs['spike_inf'] = spk_traces
    roiattrs['spike_long'] = np.squeeze(np.array(spk_long))
    return roiattrs
예제 #3
0
    def spike_traces(self, X, fps):
        """Yield c2s spike-rate predictions for every trace in ``X``.

        Work is only done when ``self.fetch1['spike_method'] == 3``; for any
        other method the generator finishes without yielding anything.

        Parameters
        ----------
        X : sequence of dict
            Every entry must contain a ``'trace'`` item. Entries are modified
            in place: ``'trace'`` is removed, ``'calcium'`` and ``'fps'`` are
            added.
        fps
            Sampling rate stored in each entry and passed to
            ``c2s.preprocess``.

        Yields
        ------
        dict
            The c2s result for one trace with ``'predictions'``,
            ``'calcium'`` and ``'fps'`` removed and ``'rate_trace'`` added.
        """
        try:
            import c2s
        except ImportError:
            # only a missing package is tolerated here; anything else
            # (including KeyboardInterrupt) must propagate
            warn(
                "c2s was not found. You won't be able to populate ExtracSpikes"
            )
        assert self.fetch1[
            'language'] == 'python', "This tuple cannot be computed in python."
        if self.fetch1['spike_method'] == 3:
            N = len(X)
            for i, trace in enumerate(X):
                print('Predicting trace %i/%i' % (i + 1, N))
                tr0 = np.array(trace.pop('trace').squeeze())
                # presumably the indices of the first and last non-NaN
                # sample (notnan is defined elsewhere) -- TODO confirm
                start = notnan(tr0)
                end = notnan(tr0, len(tr0) - 1, increment=-1)
                trace['calcium'] = np.atleast_2d(tr0[start:end + 1])

                trace['fps'] = fps
                data = c2s.preprocess([trace], fps=fps)
                data = c2s.predict(data, verbosity=0)

                # write the predictions back into the full-length trace so
                # samples outside [start, end] keep their original values
                tr0[start:end + 1] = data[0].pop('predictions')
                data[0]['rate_trace'] = tr0.T
                data[0].pop('calcium')
                data[0].pop('fps')

                yield data[0]
예제 #4
0
파일: c2s-preprocess.py 프로젝트: cajal/c2s
def main(argv):
	"""Load a dataset, preprocess it and store the result in one or more files.

	``argv[0]`` is used as the program name, the remaining items are parsed
	as command-line arguments. Always returns 0.
	"""
	parser = ArgumentParser(argv[0], description=__doc__)
	parser.add_argument('input',             type=str)
	parser.add_argument('output',            type=str, nargs='+')
	parser.add_argument('--filter',    '-s', type=int,   default=0)
	parser.add_argument('--fps',       '-f', type=float, default=100.,
		help='Up- or downsample data to match this sampling rate (100 fps).' )
	parser.add_argument('--verbosity', '-v', type=int,   default=1)

	args = parser.parse_args(argv[1:])

	# load data
	data = load_data(args.input)

	# preprocess data; a non-positive --filter value disables filtering
	data = preprocess(
		data,
		fps=args.fps,
		filter=args.filter if args.filter > 0 else None,
		verbosity=args.verbosity)

	for filepath in args.output:
		if filepath.lower().endswith('.mat'):
			# store in MATLAB format
			savemat(filepath, {'data': data})
		else:
			# protocol 2 is a binary pickle format, so the file has to be
			# opened in binary mode (text mode fails on Python 3 and can
			# corrupt the data on Windows)
			with open(filepath, 'wb') as handle:
				dump(data, handle, protocol=2)

	return 0
예제 #5
0
def main(argv):
    """Predict firing rates for a dataset and store them in one or more files.

    Parameters
    ----------
    argv : list of str
        Full argument vector. ``argv[0]`` is used as the program name, the
        remaining items are parsed as command-line arguments.

    Returns
    -------
    int
        Always 0.
    """
    parser = ArgumentParser(argv[0], description=__doc__)
    parser.add_argument('dataset', type=str)
    parser.add_argument('output', type=str, nargs='+')
    parser.add_argument('--model', '-m', type=str, default='')
    parser.add_argument(
        '--preprocess',
        '-p',
        type=int,
        default=0,
        help=
        'If you haven\'t already applied `preprocess` to the data, set to 1 (default: 0).'
    )
    parser.add_argument('--verbosity', '-v', type=int, default=1)

    args = parser.parse_args(argv[1:])

    # NOTE(review): this instance is never used below; it is kept in case
    # constructing an Experiment has side effects -- confirm, then remove
    experiment = Experiment()

    # load data
    data = load_data(args.dataset)

    if args.preprocess:
        # preprocess data; verbosity has to be passed by keyword because the
        # second positional parameter of c2s' preprocess is the target
        # sampling rate (fps), not the verbosity
        data = preprocess(data, verbosity=args.verbosity)

    if args.model:
        # load training results
        results = Experiment(args.model)['models']
    else:
        # use default model
        results = None

    # predict firing rates
    data = predict(data, results, verbosity=args.verbosity)

    # remove data except predictions
    for entry in data:
        if 'spikes' in entry:
            del entry['spikes']
        if 'spike_times' in entry:
            del entry['spike_times']
        del entry['calcium']

    for filepath in args.output:
        if filepath.lower().endswith('.mat'):
            # store in MATLAB format
            savemat(filepath, {'data': data})
        else:
            # protocol 2 is a binary pickle format, so the file has to be
            # opened in binary mode
            with open(filepath, 'wb') as handle:
                dump(data, handle, protocol=2)

    return 0
예제 #6
0
파일: c2s-predict.py 프로젝트: cajal/c2s
def main(argv):
	"""Predict firing rates for a dataset and store them in one or more files.

	``argv[0]`` is used as the program name, the remaining items are parsed
	as command-line arguments. Always returns 0.
	"""
	parser = ArgumentParser(argv[0], description=__doc__)
	parser.add_argument('dataset',            type=str)
	parser.add_argument('output',             type=str, nargs='+')
	parser.add_argument('--model',      '-m', type=str, default='')
	parser.add_argument('--preprocess', '-p', type=int, default=0,
		help='If you haven\'t already applied `preprocess` to the data, set to 1 (default: 0).')
	parser.add_argument('--verbosity',  '-v', type=int, default=1)

	args = parser.parse_args(argv[1:])

	# NOTE(review): this instance is never used below; it is kept in case
	# constructing an Experiment has side effects -- confirm, then remove
	experiment = Experiment()

	# load data
	data = load_data(args.dataset)

	if args.preprocess:
		# preprocess data; verbosity has to be passed by keyword because the
		# second positional parameter of c2s' preprocess is the target
		# sampling rate (fps), not the verbosity
		data = preprocess(data, verbosity=args.verbosity)

	if args.model:
		# load training results
		results = Experiment(args.model)['models']
	else:
		# use default model
		results = None

	# predict firing rates
	data = predict(data, results, verbosity=args.verbosity)

	# remove data except predictions
	for entry in data:
		if 'spikes' in entry:
			del entry['spikes']
		if 'spike_times' in entry:
			del entry['spike_times']
		del entry['calcium']

	for filepath in args.output:
		if filepath.lower().endswith('.mat'):
			# store in MATLAB format
			savemat(filepath, {'data': data})
		else:
			# protocol 2 is a binary pickle format, so the file has to be
			# opened in binary mode
			with open(filepath, 'wb') as handle:
				dump(data, handle, protocol=2)

	return 0
예제 #7
0
    def infer_spikes(self, X, dt, trace_name='ca_trace'):
        """Run c2s spike inference on every trace in ``X``.

        Each entry of ``X`` is modified in place: the item stored under
        ``trace_name`` is removed and its transpose is stored as
        ``'calcium'``, together with ``'fps'`` (computed as ``1 / dt``).

        Returns a list with one c2s result per trace in which
        ``'predictions'`` has been replaced by its transpose under
        ``'spike_trace'`` and ``'calcium'`` and ``'fps'`` have been removed.
        """
        assert self.fetch1['language'] == 'python', "This tuple cannot be computed in python."
        fps = 1 / dt
        total = len(X)

        def _infer(position, trace):
            # handles a single trace; position is 1-based, for the progress line
            print('Predicting trace %i/%i' % (position, total))
            trace['calcium'] = trace.pop(trace_name).T
            trace['fps'] = fps

            result = c2s.predict(c2s.preprocess([trace], fps=fps), verbosity=0)[0]
            result['spike_trace'] = result.pop('predictions').T
            for key in ('calcium', 'fps'):
                result.pop(key)
            return result

        return [_infer(position, trace) for position, trace in enumerate(X, 1)]
예제 #8
0
    def infer_spikes(self, X, dt, trace_name='ca_trace'):
        """Infer spike traces with c2s for all entries of ``X``.

        The entries of ``X`` are changed in place (the ``trace_name`` item is
        replaced by its transpose under ``'calcium'`` and ``'fps'`` is set to
        ``1 / dt``). The returned list holds, per trace, the c2s output with
        the transposed predictions stored as ``'spike_trace'`` and without
        the ``'predictions'``, ``'calcium'`` and ``'fps'`` items.
        """
        language = self.fetch1['language']
        assert language == 'python', "This tuple cannot be computed in python."

        fps = 1 / dt
        n_traces = len(X)
        results = []
        for count, trace in enumerate(X, start=1):
            print('Predicting trace %i/%i' % (count, n_traces))

            calcium = trace.pop(trace_name)
            trace['calcium'] = calcium.T
            trace['fps'] = fps

            preprocessed = c2s.preprocess([trace], fps=fps)
            predicted = c2s.predict(preprocessed, verbosity=0)

            # keep only the prediction from the returned entry
            entry = predicted[0]
            entry['spike_trace'] = entry.pop('predictions').T
            del entry['calcium']
            del entry['fps']
            results.append(entry)
        return results
예제 #9
0
def main(argv):
	"""Leave-one-cell-out evaluation: train on all other cells, predict each cell.

	For every cell a model is trained on the entries of all other cells and
	used to predict the entries of that cell. The predictions are stored in
	the file given by the ``output`` argument (MATLAB ``.mat``, ``.xpck``
	experiment, or pickle for any other extension).

	``argv[0]`` is used as the program name, the remaining items are parsed
	as command-line arguments (unknown arguments are ignored). Always
	returns 0.
	"""
	parser = ArgumentParser(argv[0], description=__doc__)
	parser.add_argument('dataset',                type=str,   nargs='+')
	parser.add_argument('output',                 type=str)
	parser.add_argument('--num_components', '-c', type=int,   default=3)
	parser.add_argument('--num_features',   '-f', type=int,   default=2)
	parser.add_argument('--num_models',     '-m', type=int,   default=4)
	parser.add_argument('--keep_all',       '-k', type=int,   default=1)
	parser.add_argument('--finetune',       '-n', type=int,   default=0)
	parser.add_argument('--num_valid',      '-s', type=int,   default=0)
	parser.add_argument('--var_explained',  '-e', type=float, default=95.)
	parser.add_argument('--window_length',  '-w', type=float, default=1000.)
	parser.add_argument('--regularize',     '-r', type=float, default=0.)
	parser.add_argument('--preprocess',     '-p', type=int,   default=0)
	parser.add_argument('--verbosity',      '-v', type=int,   default=1)

	args, _ = parser.parse_known_args(argv[1:])

	experiment = Experiment()

	# load data
	data = []
	for dataset in args.dataset:
		data = data + load_data(dataset)

	# preprocess data
	if args.preprocess:
		data = preprocess(data)

	# list of all cells
	if 'cell_num' in data[0]:
		# several trials/entries may belong to the same cell
		cells = unique([entry['cell_num'] for entry in data])
	else:
		# one cell corresponds to one trial/entry
		cells = range(len(data))
		for i in cells:
			data[i]['cell_num'] = i

	for i in cells:
		data_train = [entry for entry in data if entry['cell_num'] != i]
		data_test = [entry for entry in data if entry['cell_num'] == i]

		if args.verbosity > 0:
			# parenthesised form works on Python 2 and Python 3 alike
			print('Test cell: {0}'.format(i))

		# train on all cells but cell i
		results = train(
			data=data_train,
			num_valid=args.num_valid,
			num_models=args.num_models,
			var_explained=args.var_explained,
			window_length=args.window_length,
			keep_all=args.keep_all,
			finetune=args.finetune,
			model_parameters={
					'num_components': args.num_components,
					'num_features': args.num_features},
			training_parameters={
				'verbosity': 0},
			regularize=args.regularize,
			verbosity=1)

		if args.verbosity > 0:
			print('Predicting...')

		# predict responses of cell i
		predictions = predict(data_test, results, verbosity=0)

		# data_test holds the same dict objects as data, so this also
		# updates the entries that are saved below
		for entry1, entry2 in zip(data_test, predictions):
			entry1['predictions'] = entry2['predictions']

	# remove data except predictions
	for entry in data:
		if 'spikes' in entry:
			del entry['spikes']
		if 'spike_times' in entry:
			del entry['spike_times']
		del entry['calcium']

	# save results
	if args.output.lower().endswith('.mat'):
		savemat(args.output, convert({'data': data}))

	elif args.output.lower().endswith('.xpck'):
		experiment['args'] = args
		experiment['data'] = data
		experiment.save(args.output)

	else:
		# protocol 2 is a binary pickle format, so the file has to be
		# opened in binary mode
		with open(args.output, 'wb') as handle:
			dump(data, handle, protocol=2)

	return 0
예제 #10
0
def main(argv):
    """Leave-one-cell-out evaluation: train on all other cells, predict each cell.

    For every cell a model is trained on the entries of all other cells and
    used to predict the entries of that cell. The predictions are stored in
    the file given by the ``output`` argument (MATLAB ``.mat``, ``.xpck``
    experiment, or pickle for any other extension).

    Parameters
    ----------
    argv : list of str
        Full argument vector. ``argv[0]`` is used as the program name, the
        remaining items are parsed as command-line arguments; unknown
        arguments are ignored.

    Returns
    -------
    int
        Always 0.
    """
    parser = ArgumentParser(argv[0], description=__doc__)
    parser.add_argument('dataset', type=str, nargs='+')
    parser.add_argument('output', type=str)
    parser.add_argument('--num_components', '-c', type=int, default=3)
    parser.add_argument('--num_features', '-f', type=int, default=2)
    parser.add_argument('--num_models', '-m', type=int, default=4)
    parser.add_argument('--keep_all', '-k', type=int, default=1)
    parser.add_argument('--finetune', '-n', type=int, default=0)
    parser.add_argument('--num_valid', '-s', type=int, default=0)
    parser.add_argument('--var_explained', '-e', type=float, default=95.)
    parser.add_argument('--window_length', '-w', type=float, default=1000.)
    parser.add_argument('--regularize', '-r', type=float, default=0.)
    parser.add_argument('--preprocess', '-p', type=int, default=0)
    parser.add_argument('--verbosity', '-v', type=int, default=1)

    args, _ = parser.parse_known_args(argv[1:])

    experiment = Experiment()

    # load data
    data = []
    for dataset in args.dataset:
        data = data + load_data(dataset)

    # preprocess data
    if args.preprocess:
        data = preprocess(data)

    # list of all cells
    if 'cell_num' in data[0]:
        # several trials/entries may belong to the same cell
        cells = unique([entry['cell_num'] for entry in data])
    else:
        # one cell corresponds to one trial/entry
        cells = range(len(data))
        for i in cells:
            data[i]['cell_num'] = i

    for i in cells:
        data_train = [entry for entry in data if entry['cell_num'] != i]
        data_test = [entry for entry in data if entry['cell_num'] == i]

        if args.verbosity > 0:
            # parenthesised form works on Python 2 and Python 3 alike
            print('Test cell: {0}'.format(i))

        # train on all cells but cell i
        results = train(data=data_train,
                        num_valid=args.num_valid,
                        num_models=args.num_models,
                        var_explained=args.var_explained,
                        window_length=args.window_length,
                        keep_all=args.keep_all,
                        finetune=args.finetune,
                        model_parameters={
                            'num_components': args.num_components,
                            'num_features': args.num_features
                        },
                        training_parameters={'verbosity': 0},
                        regularize=args.regularize,
                        verbosity=1)

        if args.verbosity > 0:
            print('Predicting...')

        # predict responses of cell i
        predictions = predict(data_test, results, verbosity=0)

        # data_test holds the same dict objects as data, so this also
        # updates the entries that are saved below
        for entry1, entry2 in zip(data_test, predictions):
            entry1['predictions'] = entry2['predictions']

    # remove data except predictions
    for entry in data:
        if 'spikes' in entry:
            del entry['spikes']
        if 'spike_times' in entry:
            del entry['spike_times']
        del entry['calcium']

    # save results
    if args.output.lower().endswith('.mat'):
        savemat(args.output, {'data': data})

    elif args.output.lower().endswith('.xpck'):
        experiment['args'] = args
        experiment['data'] = data
        experiment.save(args.output)

    else:
        # protocol 2 is a binary pickle format, so the file has to be
        # opened in binary mode
        with open(args.output, 'wb') as handle:
            dump(data, handle, protocol=2)

    return 0
예제 #11
0
def main(argv):
    """Plot the calcium traces of selected cells, one subplot per cell.

    Spike times are drawn as short vertical ticks below the trace when an
    entry provides them. The figure is written to the file given by
    ``--output`` or, without that option, shown interactively.

    ``argv[0]`` is used as the program name, the remaining items are parsed
    as command-line arguments. Always returns 0.
    """
    parser = ArgumentParser(argv[0], description=__doc__)
    parser.add_argument('dataset', type=str)

    # options as (long flag, short flag, type, default, extra keyword arguments)
    for long_flag, short_flag, kind, default, extras in (
            ('--preprocess', '-p', int, 0, {}),
            ('--output', '-o', str, '', {}),
            ('--seconds', '-S', int, 60, {}),
            ('--offset', '-O', int, 0, {}),
            ('--width', '-W', int, 10, {}),
            ('--height', '-H', int, 0, {}),
            ('--cells', '-c', int, [], {'nargs': '+'}),
            ('--dpi', '-D', int, 100, {}),
            ('--font', '-F', str, 'Arial', {})):
        parser.add_argument(
            long_flag, short_flag, type=kind, default=default, **extras)

    args = parser.parse_args(argv[1:])

    # load data and keep the requested cells (cell numbers are 1-based)
    traces = load_data(args.dataset)
    cells = args.cells or range(1, len(traces) + 1)
    data = [traces[c - 1] for c in cells]

    if args.preprocess:
        data = preprocess(data)

    plt.rcParams['font.family'] = args.font
    plt.rcParams['savefig.dpi'] = args.dpi

    # without an explicit height the figure grows with the number of cells
    num_rows = len(data)
    fig_height = args.height if args.height > 0 else num_rows * 1.5 + .3
    plt.figure(figsize=(args.width, fig_height))

    t_start = args.offset
    t_stop = args.offset + args.seconds

    for k, entry in enumerate(data):
        fps = entry['fps']
        first = int(fps * args.offset)
        count = int(fps * args.seconds)
        calcium = entry['calcium'].ravel()[first:first + count]

        plt.subplot(num_rows, 1, k + 1)
        plt.plot(t_start + arange(calcium.size) / fps,
                 calcium,
                 color=(.1, .6, .4))

        if 'spike_times' in entry:
            # presumably converts milliseconds to seconds -- TODO confirm
            spike_times = entry['spike_times'].ravel() / 1000.
            in_window = logical_and(spike_times > t_start,
                                    spike_times < t_stop)

            for st in spike_times[in_window]:
                plt.plot([st, st], [-1, -.5], 'k', lw=1.5)

        plt.yticks([])
        plt.ylim([-2., 5.])
        plt.xlim([t_start, t_stop])
        plt.ylabel('Cell {0}'.format(cells[k]))
        plt.grid()

        # only the bottom subplot keeps its tick labels
        if k < num_rows - 1:
            plt.xticks(plt.xticks()[0], [])

    plt.xlabel('Time [seconds]')
    plt.tight_layout()

    if args.output:
        plt.savefig(args.output)
    else:
        plt.show()

    return 0
예제 #12
0
파일: c2s-train.py 프로젝트: fabiansinz/c2s
def main(argv):
	"""Train models on one or more datasets and save them as an experiment.

	``argv[0]`` is used as the program name, the remaining items are parsed
	as command-line arguments (unknown arguments are ignored). The trained
	models are saved to ``output``, or to ``model.xpck`` inside it when
	``output`` is an existing directory. Always returns 0.
	"""
	parser = ArgumentParser(argv[0], description=__doc__,
		formatter_class=lambda prog: HelpFormatter(prog, max_help_position=10, width=120))
	parser.add_argument('dataset',                type=str, nargs='+',
		help='Dataset(s) used for training.')
	parser.add_argument('output',                 type=str,
		help='Directory or file where trained models will be stored.')
	parser.add_argument('--num_components', '-c', type=int,   default=3,
		help='Number of components used in STM model (default: %(default)d).')
	parser.add_argument('--num_features',   '-f', type=int,   default=2,
		help='Number of quadratic features used in STM model (default: %(default)d).')
	parser.add_argument('--num_models',     '-m', type=int,   default=4,
		help='Number of models trained (predictions will be averaged across models, default: %(default)d).')
	parser.add_argument('--keep_all',       '-k', type=int,   default=1,
		help='If set to 0, only the best model of all trained models is kept (default: %(default)d).')
	parser.add_argument('--finetune',       '-n', type=int,   default=0,
		help='If set to 1, enables another finetuning step which is performed after training (default: %(default)d).')
	parser.add_argument('--num_train',      '-t', type=int,   default=0,
		help='If specified, a (random) subset of cells is used for training.')
	parser.add_argument('--num_valid',      '-s', type=int,   default=0,
		help='If specified, a (random) subset of cells will be used for early stopping based on validation error.')
	parser.add_argument('--var_explained',  '-e', type=float, default=95.,
		help='Controls the degree of dimensionality reduction of fluorescence windows (default: %(default).0f).')
	parser.add_argument('--window_length',  '-w', type=float, default=1000.,
		help='Length of windows extracted from calcium signal for prediction (in milliseconds, default: %(default).0f).')
	parser.add_argument('--regularize',     '-r', type=float, default=0.,
		help='Amount of parameter regularization (filters are regularized for smoothness, default: %(default).1f).')
	parser.add_argument('--preprocess',     '-p', type=int,   default=0,
		help='If the data is not already preprocessed, this can be used to do it.')
	parser.add_argument('--verbosity',      '-v', type=int,   default=1)

	args, _ = parser.parse_known_args(argv[1:])

	experiment = Experiment()

	if not args.dataset:
		# parenthesised form works on Python 2 and Python 3 alike
		print('You have to specify at least 1 dataset.')
		return 0

	data = []
	for dataset in args.dataset:
		# NOTE: `load` is presumably pickle's; unpickling can execute
		# arbitrary code, so only trusted dataset files should be used.
		# Pickle data is binary, hence the binary mode.
		with open(dataset, 'rb') as handle:
			data = data + load(handle)

	if args.preprocess:
		# verbosity has to be passed by keyword because the second positional
		# parameter of c2s' preprocess is the target sampling rate (fps)
		data = preprocess(data, verbosity=args.verbosity)

	if 'cell_num' not in data[0]:
		# no cell number is given, assume traces correspond to cells
		for k, entry in enumerate(data):
			entry['cell_num'] = k

	# collect cell ids
	cell_ids = unique([entry['cell_num'] for entry in data])

	# pick cells for training
	# NOTE(review): both branches produce values in range(len(cell_ids)),
	# which are compared against 'cell_num' below; this only selects the
	# intended cells if cell numbers are 0..len(cell_ids)-1 -- confirm
	if args.num_train > 0:
		training_cells = random_select(args.num_train, len(cell_ids))
	else:
		# use all cells for training
		training_cells = range(len(cell_ids))

	models = train([entry for entry in data if entry['cell_num'] in training_cells],
		num_valid=args.num_valid,
		num_models=args.num_models,
		var_explained=args.var_explained,
		window_length=args.window_length,
		keep_all=args.keep_all,
		finetune=args.finetune,
		model_parameters={
			'num_components': args.num_components,
			'num_features': args.num_features},
		training_parameters={
			'verbosity': 1},
		regularize=args.regularize,
		verbosity=args.verbosity)

	experiment['args'] = args
	experiment['training_cells'] = training_cells
	experiment['models'] = models

	if os.path.isdir(args.output):
		experiment.save(os.path.join(args.output, 'model.xpck'))
	else:
		experiment.save(args.output)

	return 0
예제 #13
0
def main(argv):
	"""Plot calcium traces (and spike times, where available) of selected cells.

	Every selected cell gets its own subplot. The figure is saved to the
	file given by ``--output`` if that option is set and displayed otherwise.

	``argv[0]`` is used as the program name, the remaining items are parsed
	as command-line arguments. Always returns 0.
	"""
	parser = ArgumentParser(argv[0], description=__doc__)
	parser.add_argument('dataset', type=str)
	parser.add_argument('--preprocess', '-p', type=int, default=0)
	parser.add_argument('--output', '-o', type=str, default='')
	parser.add_argument('--seconds', '-S', type=int, default=60)
	parser.add_argument('--offset', '-O', type=int, default=0)
	parser.add_argument('--width', '-W', type=int, default=10)
	parser.add_argument('--height', '-H', type=int, default=0)
	parser.add_argument('--cells', '-c', type=int, default=[], nargs='+')
	parser.add_argument('--dpi', '-D', type=int, default=100)
	parser.add_argument('--font', '-F', type=str, default='Arial')

	args = parser.parse_args(argv[1:])

	# load data; cells are addressed by 1-based number, all cells by default
	loaded = load_data(args.dataset)
	if args.cells:
		cells = args.cells
	else:
		cells = range(1, len(loaded) + 1)
	data = [loaded[c - 1] for c in cells]

	if args.preprocess:
		data = preprocess(data)

	plt.rcParams['font.family'] = args.font
	plt.rcParams['savefig.dpi'] = args.dpi

	# a non-positive --height lets the figure height follow the cell count
	if args.height > 0:
		height = args.height
	else:
		height = len(data) * 1.5 + .3
	plt.figure(figsize=(args.width, height))

	last = len(data) - 1
	window = [args.offset, args.offset + args.seconds]

	for k, entry in enumerate(data):
		# sample range covered by the requested time window
		begin = int(entry['fps'] * args.offset)
		num_samples = int(entry['fps'] * args.seconds)
		calcium = entry['calcium'].ravel()[begin:begin + num_samples]
		time = args.offset + arange(calcium.size) / entry['fps']

		plt.subplot(len(data), 1, k + 1)
		plt.plot(time, calcium, color=(.1, .6, .4))

		if 'spike_times' in entry:
			# presumably converts milliseconds to seconds -- TODO confirm
			spike_times = entry['spike_times'].ravel() / 1000.
			visible = logical_and(
				spike_times > window[0],
				spike_times < window[1])

			for st in spike_times[visible]:
				plt.plot([st, st], [-1, -.5], 'k', lw=1.5)

		plt.yticks([])
		plt.ylim([-2., 5.])
		plt.xlim(window)
		plt.ylabel('Cell {0}'.format(cells[k]))
		plt.grid()

		# hide tick labels on all but the bottom subplot
		if k < last:
			plt.xticks(plt.xticks()[0], [])

	plt.xlabel('Time [seconds]')
	plt.tight_layout()

	if not args.output:
		plt.show()
	else:
		plt.savefig(args.output)

	return 0
            df_F.append((corrected_trace - ffilt)/ffilt)
            raw_traces.append(trace)
            corr_traces.append(corrected_trace)

        print "running spike extraction... \n"

        #gInfo = pickle.load(open(hdf['raw_data'][session][area].attrs['GRABinfo']))
        frameRate = gInfo['scanFrameRate']
        #print np.array([corr_traces[0]]).shape
        inf = []
        for i in range(n_neurons):
            sys.stdout.write("\rrunning inference on cell: "+str(1+i)+"/"+str(n_neurons))
            sys.stdout.flush()
            data = [{'calcium':np.array([corr_traces[i]]),'fps': frameRate}]
        #data = [{'calcium':np.array([i]),'fps': frameRate} for i in corr_traces]
            inf.append(c2s.predict(c2s.preprocess(data),verbosity=0))

        print 'Saving Data'
        roiInfo['traces'] = np.array(raw_traces)
        roiInfo['corr_traces'] = np.array(corr_traces)
        roiInfo['df_F'] = np.array(df_F)
        roiInfo['spikeRate_inf'] = inf#np.array([i['predictions'] for i in inf])


        roiInfo['info'] = ['traces are raw traces',
                           'corr_traces are neuropil corrected traces',
                           'idxs are x and y coordinates of ROIs',
                           'spikeRate_inf are inferred spike rate using Theis et al 2015',
                           'df_F are neuropil corrected df/F traces',
                           'centres are the locations of the centre of the ROIs',
                           'patches are cut out patches of the mean image around the ROI',
예제 #15
0
def c2s_preprocess_parallel(argsdict):
    """Call ``c2s.preprocess`` with the keyword arguments in ``argsdict``.

    Returns the full result when ``argsdict['data']`` has more than one
    element, otherwise only the first element of the result. The process id
    is logged before the call.
    """
    log = logging.getLogger(funcname())
    log.info('%d start' % os.getpid())

    # decide before preprocessing whether the result has to be unwrapped
    has_several = len(argsdict['data']) > 1
    processed = c2s.preprocess(**argsdict)
    return processed if has_several else processed[0]