Example #1
def main(argv):
	parser = ArgumentParser(argv[0], description=__doc__)
	parser.add_argument('dataset',                type=str,   nargs='+')
	parser.add_argument('output',                 type=str)
	parser.add_argument('--num_components', '-c', type=int,   default=3)
	parser.add_argument('--num_features',   '-f', type=int,   default=2)
	parser.add_argument('--num_models',     '-m', type=int,   default=4)
	parser.add_argument('--keep_all',       '-k', type=int,   default=1)
	parser.add_argument('--finetune',       '-n', type=int,   default=0)
	parser.add_argument('--num_valid',      '-s', type=int,   default=0)
	parser.add_argument('--var_explained',  '-e', type=float, default=95.)
	parser.add_argument('--window_length',  '-w', type=float, default=1000.)
	parser.add_argument('--regularize',     '-r', type=float, default=0.)
	parser.add_argument('--preprocess',     '-p', type=int,   default=0)
	parser.add_argument('--verbosity',      '-v', type=int,   default=1)

	args, _ = parser.parse_known_args(argv[1:])

	experiment = Experiment()

	# load data
	data = []
	for dataset in args.dataset:
		data = data + load_data(dataset)

	# preprocess data
	if args.preprocess:
		data = preprocess(data)

	# list of all cells
	if 'cell_num' in data[0]:
		# several trials/entries may belong to the same cell
		cells = unique([entry['cell_num'] for entry in data])
	else:
		# one cell corresponds to one trial/entry
		cells = range(len(data))
		for i in cells:
			data[i]['cell_num'] = i

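	# leave-one-out cross-validation: hold out one cell at a time, train on
	# the remaining cells and predict the responses of the held-out cell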
	for i in cells:
		data_train = [entry for entry in data if entry['cell_num'] != i]
		data_test = [entry for entry in data if entry['cell_num'] == i]

		if args.verbosity > 0:
			print 'Test cell: {0}'.format(i)

		# train on all cells but cell i
		results = train(
			data=data_train,
			num_valid=args.num_valid,
			num_models=args.num_models,
			var_explained=args.var_explained,
			window_length=args.window_length,
			keep_all=args.keep_all,
			finetune=args.finetune,
			model_parameters={
					'num_components': args.num_components,
					'num_features': args.num_features},
			training_parameters={
				'verbosity': 0},
			regularize=args.regularize,
			verbosity=1)

		if args.verbosity > 0:
			print 'Predicting...'

		# predict responses of cell i
		predictions = predict(data_test, results, verbosity=0)

		for entry1, entry2 in zip(data_test, predictions):
			entry1['predictions'] = entry2['predictions']

	# remove data except predictions
	for entry in data:
		if 'spikes' in entry:
			del entry['spikes']
		if 'spike_times' in entry:
			del entry['spike_times']
		del entry['calcium']

	# save results
	if args.output.lower().endswith('.mat'):
		savemat(args.output, convert({'data': data}))

	elif args.output.lower().endswith('.xpck'):
		experiment['args'] = args
		experiment['data'] = data
		experiment.save(args.output)

	else:
		with open(args.output, 'wb') as handle:
			dump(data, handle, protocol=2)

	return 0
Example #2
def main(argv):
	parser = ArgumentParser(argv[0], description=__doc__,
		formatter_class=lambda prog: HelpFormatter(prog, max_help_position=10, width=120))
	parser.add_argument('dataset',                type=str, nargs='+',
		help='Dataset(s) used for training.')
	parser.add_argument('output',                 type=str,
		help='Directory or file where trained models will be stored.')
	parser.add_argument('--num_components', '-c', type=int,   default=3,
		help='Number of components used in STM model (default: %(default)d).')
	parser.add_argument('--num_features',   '-f', type=int,   default=2,
		help='Number of quadratic features used in STM model (default: %(default)d).')
	parser.add_argument('--num_models',     '-m', type=int,   default=4,
		help='Number of models trained (predictions will be averaged across models, default: %(default)d).')
	parser.add_argument('--keep_all',       '-k', type=int,   default=1,
		help='If set to 0, only the best model of all trained models is kept (default: %(default)d).')
	parser.add_argument('--finetune',       '-n', type=int,   default=0,
		help='If set to 1, enables an additional fine-tuning step after training (default: %(default)d).')
	parser.add_argument('--num_train',      '-t', type=int,   default=0,
		help='If specified, a (random) subset of cells is used for training.')
	parser.add_argument('--num_valid',      '-s', type=int,   default=0,
		help='If specified, a (random) subset of cells will be used for early stopping based on validation error.')
	parser.add_argument('--var_explained',  '-e', type=float, default=95.,
		help='Controls the degree of dimensionality reduction of fluorescence windows (default: %(default).0f).')
	parser.add_argument('--window_length',  '-w', type=float, default=1000.,
		help='Length of windows extracted from calcium signal for prediction (in milliseconds, default: %(default).0f).')
	parser.add_argument('--regularize',     '-r', type=float, default=0.,
		help='Amount of parameter regularization (filters are regularized for smoothness, default: %(default).1f).')
	parser.add_argument('--preprocess',     '-p', type=int,   default=0,
		help='Set to 1 to preprocess the data if it is not already preprocessed (default: %(default)d).')
	parser.add_argument('--verbosity',      '-v', type=int,   default=1)

	args, _ = parser.parse_known_args(argv[1:])

	experiment = Experiment()

	if not args.dataset:
		print 'You have to specify at least one dataset.'
		return 0

	data = []
	for dataset in args.dataset:
		with open(dataset) as handle:
			data = data + load(handle)

	if args.preprocess:
		data = preprocess(data, args.verbosity)

	if 'cell_num' not in data[0]:
		# no cell number is given, assume traces correspond to cells
		for k, entry in enumerate(data):
			entry['cell_num'] = k

	# collect cell ids
	cell_ids = unique([entry['cell_num'] for entry in data])
	
	# pick cells for training
	if args.num_train > 0:
		training_cells = random_select(args.num_train, len(cell_ids))
	else:
		# use all cells for training
		training_cells = range(len(cell_ids))

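	# train an ensemble of models on the selected training cells; predictions
	# will later be averaged across the trained models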
	models = train([entry for entry in data if entry['cell_num'] in training_cells],
		num_valid=args.num_valid,
		num_models=args.num_models,
		var_explained=args.var_explained,
		window_length=args.window_length,
		keep_all=args.keep_all,
		finetune=args.finetune,
		model_parameters={
			'num_components': args.num_components,
			'num_features': args.num_features},
		training_parameters={
			'verbosity': 1},
		regularize=args.regularize,
		verbosity=args.verbosity)

	experiment['args'] = args
	experiment['training_cells'] = training_cells
	experiment['models'] = models

	if os.path.isdir(args.output):
		experiment.save(os.path.join(args.output, 'model.xpck'))
	else:
		experiment.save(args.output)

	return 0
Example #3
def main(argv):
    parser = ArgumentParser(argv[0], description=__doc__)
    parser.add_argument('dataset', type=str, nargs='+')
    parser.add_argument('output', type=str)
    parser.add_argument('--num_components', '-c', type=int, default=3)
    parser.add_argument('--num_features', '-f', type=int, default=2)
    parser.add_argument('--num_models', '-m', type=int, default=4)
    parser.add_argument('--keep_all', '-k', type=int, default=1)
    parser.add_argument('--finetune', '-n', type=int, default=0)
    parser.add_argument('--num_valid', '-s', type=int, default=0)
    parser.add_argument('--var_explained', '-e', type=float, default=95.)
    parser.add_argument('--window_length', '-w', type=float, default=1000.)
    parser.add_argument('--regularize', '-r', type=float, default=0.)
    parser.add_argument('--preprocess', '-p', type=int, default=0)
    parser.add_argument('--verbosity', '-v', type=int, default=1)

    args, _ = parser.parse_known_args(argv[1:])

    experiment = Experiment()

    # load data
    data = []
    for dataset in args.dataset:
        data = data + load_data(dataset)

    # preprocess data
    if args.preprocess:
        data = preprocess(data)

    # list of all cells
    if 'cell_num' in data[0]:
        # several trials/entries may belong to the same cell
        cells = unique([entry['cell_num'] for entry in data])
    else:
        # one cell corresponds to one trial/entry
        cells = range(len(data))
        for i in cells:
            data[i]['cell_num'] = i

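    # leave-one-out cross-validation: hold out one cell at a time, train on
    # the remaining cells and predict the responses of the held-out cell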
    for i in cells:
        data_train = [entry for entry in data if entry['cell_num'] != i]
        data_test = [entry for entry in data if entry['cell_num'] == i]

        if args.verbosity > 0:
            print 'Test cell: {0}'.format(i)

        # train on all cells but cell i
        results = train(data=data_train,
                        num_valid=args.num_valid,
                        num_models=args.num_models,
                        var_explained=args.var_explained,
                        window_length=args.window_length,
                        keep_all=args.keep_all,
                        finetune=args.finetune,
                        model_parameters={
                            'num_components': args.num_components,
                            'num_features': args.num_features
                        },
                        training_parameters={'verbosity': 0},
                        regularize=args.regularize,
                        verbosity=1)

        if args.verbosity > 0:
            print 'Predicting...'

        # predict responses of cell i
        predictions = predict(data_test, results, verbosity=0)

        for entry1, entry2 in zip(data_test, predictions):
            entry1['predictions'] = entry2['predictions']

    # remove data except predictions
    for entry in data:
        if 'spikes' in entry:
            del entry['spikes']
        if 'spike_times' in entry:
            del entry['spike_times']
        del entry['calcium']

    # save results
    if args.output.lower().endswith('.mat'):
        savemat(args.output, {'data': data})

    elif args.output.lower().endswith('.xpck'):
        experiment['args'] = args
        experiment['data'] = data
        experiment.save(args.output)

    else:
        with open(args.output, 'wb') as handle:
            dump(data, handle, protocol=2)

    return 0
Example #4
    def fit(self, dataset_paths, model_path=None, folds=5, error_margin=2):
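        # Fit c2s spike-inference models. If no model_path is given, traces and
        # spikes are extracted from the datasets, converted into C2S data
        # dictionaries, preprocessed in parallel and serialized to disk;
        # otherwise previously serialized data is loaded. The models are then
        # trained and evaluated via the correlation between predicted and
        # recorded spikes.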

        logger = logging.getLogger(funcname())

        if not model_path:

            # Extract traces and spikes from datasets.
            traces = [self.dataset_traces_func(p) for p in dataset_paths]
            spikes = [self.dataset_spikes_func(p) for p in dataset_paths]
            attrs = [self.dataset_attrs_func(p) for p in dataset_paths]
            assert len(traces) == len(spikes) == len(attrs)

            # Populate C2S data dictionaries.
            data = []
            for i in range(len(attrs)):
                for t, s in zip(traces[i], spikes[i]):
                    data.append({'calcium': t[np.newaxis],
                                 'spikes': s[np.newaxis],
                                 'fps': attrs[i]['sample_rate']})

            # Preprocess in parallel; this is slow. Using a lower fps yields
            # smaller vectors, and very large vectors can crash the training.
            pool = Pool(max(1, cpu_count() - 2))
            args = [{'data': [d], 'fps': 10, 'verbosity':0} for d in data]
            data = pool.map(c2s_preprocess_parallel, args)

            # Serialize data.
            data_path = '%s/%d_data.pkl' % (self.cpdir, int(time()))
            with open(data_path, 'wb') as fp:
                pkl.dump(data, fp)
            logger.info('Serialized preprocessed data to %s' % data_path)

        else:
            with open(model_path, 'rb') as fp:
                data = pkl.load(fp)

        # Train.
        results = c2s.train(data)

        # Predict.
        data_trn = c2s.predict(data, results)

        # Evaluate using C2S metrics.
        downsample_factor = 10  # fps = 100 -> fps = 10.
        corr = np.nan_to_num(c2s.evaluate(
            data, 'corr', downsampling=downsample_factor), copy=False)
        print('Corr = %.5lf' % np.mean(corr))

        # # Compute metrics.
        # p, r = 0., 0.
        # for i, d in enumerate(data_trn):
        #     yt = d['spikes'][0, np.newaxis]
        #     yp = np.clip(d['predictions'][0, np.newaxis].round(), 0, 1)
        #
        #     p += np2k(prec_margin, yt, yp, margin=error_margin)
        #     r += np2k(reca_margin, yt, yp, margin=error_margin)
        #
        #     if i % 50 == 0 or i == len(data_trn) - 1:
        #         print '%03d: mean p=%-10.3lf mean r=%-10.3lf' % (i, (p / i), (r / i))
        #
        # p, r = 0., 0.
        # for i, d in enumerate(data_val):
        #     yt = d['spikes'][0, np.newaxis]
        #     yp = np.clip(d['predictions'][0, np.newaxis].round(), 0, 1)
        #
        #     p += np2k(prec_margin, yt, yp, margin=error_margin)
        #     r += np2k(reca_margin, yt, yp, margin=error_margin)
        #
        #     if i % 50 == 0 or i == len(data_val) - 1:
        #         print '%03d: mean p=%-10.3lf mean r=%-10.3lf' % (i, (p / i), (r / i))

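The helper c2s_preprocess_parallel used in Example #4 is not part of the snippet shown. A minimal sketch, assuming it simply unpacks the keyword dictionary, forwards it to c2s.preprocess and returns the single preprocessed entry (the actual helper may do more):

import c2s

def c2s_preprocess_parallel(kwargs):
    # Hypothetical wrapper (an assumption, not shown in the original code):
    # multiprocessing.Pool can only pickle module-level functions, so the
    # keyword dictionary built in fit() is unpacked here. Each call receives a
    # single-entry dataset, so only the first preprocessed entry is returned.
    return c2s.preprocess(**kwargs)[0]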