Example 1
    def refresh_batches(self):

        sql = "SELECT * FROM Analysis order by id"
        self.analysis_data = nd.pd_query(sql)
        model = QtGui.QStandardItemModel()

        for i in self.analysis_data['name'].to_list():
            item = QtGui.QStandardItem(i)
            model.appendRow(item)
        self.comboAnalysis.setModel(model)

        index = self.comboAnalysis.findText(self.current_analysis,
                                            QtCore.Qt.MatchFixedString)
        if index >= 0:
            self.comboAnalysis.setCurrentIndex(index)

        sql = "SELECT DISTINCT batch FROM Batches order by batch"
        self.batch_data = nd.pd_query(sql)
        model = QtGui.QStandardItemModel()
        for i in self.batch_data['batch'].to_list():
            item = QtGui.QStandardItem(i)
            model.appendRow(item)
        self.comboBatch.setModel(model)

        self.analysis_update()
Example 2
    def update_widgets(self):

        batch = int(self.batchLE.text())
        cellmask = self.cellLE.text() + "%"
        # strip white spaces and split by semicolon
        modeltext = self.modelLE.text().replace(' ', '').split(';')
        # attach % wildcard before and after modelstring
        modelmask = ['%' + m + '%' for m in modeltext]
        #modelmask = "%" + self.modelLE.text() + "%"

        save_settings(self)

        if batch > 0:
            self.batch = batch
        else:
            self.batchLE.setText(str(self.batch))
        #" ORDER BY cellid",

        #self.d_cells = nd.get_batch_cells(self.batch, cellid=cellmask)

        self.d_cells = nd.pd_query("SELECT DISTINCT cellid FROM Results" +
                               " WHERE batch=%s AND cellid like %s" +
                               " ORDER BY cellid",
                               (self.batch, cellmask))

        # wrap the OR'd modelname clauses in parentheses so the batch
        # constraint applies to all of them
        modelquery = ("SELECT modelname, count(*) as n, max(lastmod) as "
                      "last_mod FROM Results WHERE batch=%s AND (")
        for i, m in enumerate(modelmask):
            modelquery += 'modelname like %s OR '
        modelquery = modelquery[:-4]  # drop the trailing ' OR '
        modelquery += ') GROUP BY modelname ORDER BY modelname'
        self.d_models = nd.pd_query(modelquery, (self.batch, *modelmask))

#        self.d_models = nd.pd_query("SELECT modelname, count(*) as n, max(lastmod) as last_mod FROM Results" +
#                               " WHERE batch=%s AND modelname like %s" +
#                               " GROUP BY modelname ORDER BY modelname",
#                               (self.batch, modelmask))

        self.cells.clear()
        for c in list(self.d_cells['cellid']):
            list_item = qw.QListWidgetItem(c, self.cells)

        self.models.clear()
        for m in list(self.d_models['modelname']):
            list_item = qw.QListWidgetItem(m, self.models)
        #self.data_model._df = self.d_models
        #self.data_table.setModel(self.data_model)

        print('updated list widgets')
Example 3
    def save_analysis(self):
        video_name = self.video_name.get()

        fn = video_name + '.pickle'
        fn_mat = video_name + '.mat'
        fp = os.path.split(self.processed_video)[0]
        save_path = os.path.join(fp, fn)
        # for matlab loading
        mat_fn = os.path.join(fp, fn_mat)

        sorted_dir = os.path.split(save_path)[0]

        if not os.path.isdir(sorted_dir):
            # create sorted directory and force it to be world-writeable
            os.system("mkdir {}".format(sorted_dir))
            os.system("chmod a+w {}".format(sorted_dir))
            print("created new directory {0}".format(sorted_dir))

        save_dict = self.parms

        # add excluded frames to the save dictionary
        excluded_frames = np.concatenate(
            (np.array(self.exclude_starts)[np.newaxis, :],
             np.array(self.exclude_ends)[np.newaxis, :]),
            axis=0)
        save_dict['cnn']['excluded_frames'] = excluded_frames.T

        print("computing eyespeed")
        x_diff = np.diff(save_dict['cnn']['x'])
        y_diff = np.diff(save_dict['cnn']['y'])
        d = np.sqrt((x_diff**2) + (y_diff**2))
        d[-1] = 0
        d = np.concatenate((d, np.zeros(1)))
        save_dict['cnn']['eyespeed'] = d

        with open(save_path, 'wb') as fp:
            pickle.dump(save_dict, fp, protocol=pickle.HIGHEST_PROTOCOL)

        scipy.io.savemat(mat_fn, save_dict)

        # finally, update celldb to mark pupil as analyzed
        # see if the eyecalfile is correct (might've been a flush error, in which case it'll still say L:/)
        get_file1 = "SELECT eyecalfile from gDataRaw where eyecalfile='{0}'".format(
            self.raw_video)
        out1 = nd.pd_query(get_file1)
        if out1.shape[0] == 0:
            # try the L:/ path
            og_video_path = self.raw_video.replace('/auto/data/daq/', 'L:/')
            sql = "UPDATE gDataRaw SET eyewin=2 WHERE eyecalfile='{}'".format(
                og_video_path)
        else:
            sql = "UPDATE gDataRaw SET eyewin=2 WHERE eyecalfile='{}'".format(
                self.raw_video)
        nd.sql_command(sql)

        print("saved analysis successfully")
Example 4
def get_site_data(siteid):
    """
    :param siteid:
    :return:
    """
    sql = f"SELECT * FROM gCellMaster WHERE siteid like '{siteid}'"
    d = db.pd_query(sql)

    return d
Example 5
def get_single_cell_data(cellid):
    """
    :param cellid: single cellid or siteid
    :return:
    """
    sql = f"SELECT * FROM gSingleCell WHERE cellid like '{cellid}%%'"
    d = db.pd_query(sql)

    return d
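Note: the two helpers above interpolate the identifier straight into the SQL string. A minimal sketch of an equivalent parameterized call, assuming db.pd_query forwards a params tuple the same way nd.pd_query does in the other examples:

def get_single_cell_data_param(cellid):
    # hypothetical variant: let the driver substitute the LIKE pattern
    sql = "SELECT * FROM gSingleCell WHERE cellid like %s"
    return db.pd_query(sql, params=(cellid + "%",))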
Example 6
    def load_file(self):
        """
        Load the overall predictions and plot the trace on the trace canvas.
        Display the first frame of the video on the pupil canvas.
        """

        params_file = self.video_name.get()
        # get raw video -- try to use the existing path from raw video
        fp = os.path.split(self.raw_video)[0]

        self.processed_video = os.path.join(fp, 'sorted', params_file)

        # reset raw video attribute
        ext = self.raw_video.split('.')[-1]
        if len(self.raw_video.split('.')) > 2:
            ext2 = self.raw_video.split('.')[-2]
            self.raw_video = os.path.join(fp,
                                          params_file) + '.' + ext2 + '.' + ext
        else:
            self.raw_video = os.path.join(fp, params_file) + '.' + ext
        print(self.raw_video)

        self.plot_trace(params_file)
        try:
            self.plot_eyelid_movement(params_file)
        except Exception:
            print("Couldn't load eyelid keypoints -- old fit?")
        self.frame_n_value.insert(0, str(0))

        # reset exclusion frames
        self.exclude_starts = []
        self.exclude_ends = []

        # save first ten frames and display the first
        video = self.raw_video

        # first make sure the tmp file doesn't exist for this user, just to avoid asking for overwrite permissions
        os.system(f"rm {tmp_frame_folder}frame1_{getpass.getuser()}.jpg")
        os.system(
            f"ffmpeg -ss 00:00:00 -i {video} -vframes 1 {tmp_frame_folder}frame1_{getpass.getuser()}.jpg"
        )

        frame_file = tmp_frame_folder + f'frame1_{getpass.getuser()}.jpg'

        # define the species for this animal by querying the database
        self.species = nd.pd_query(
            f"SELECT species from gAnimal where animal='{self.animal_name.get()}'"
        ).values[0][0]

        self.plot_frame(frame_file)
        self.master.mainloop()
Example 7
def _queue_fits(batch, modelnames, iterator):
    for siteid in iterator:
        for modelname in modelnames:
            do_fit = True
            if not FORCE_RERUN:
                d = nd.pd_query(
                    "SELECT * FROM Results WHERE cellid like %s and modelname=%s and batch=%s",
                    params=(siteid + "%", modelname, batch))
                if len(d) > 0:
                    do_fit = False
                    print(f'Fit exists for {siteid} {batch} {modelname}')
            if do_fit:
                enqueue_exacloud_models(cellist=[siteid],
                                        batch=batch,
                                        modellist=[modelname],
                                        useGPU=True,
                                        **EXACLOUD_SETTINGS)
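A hypothetical invocation of _queue_fits; the batch, site and modelname below are placeholders borrowed from other examples on this page, and FORCE_RERUN / EXACLOUD_SETTINGS are assumed to be defined at module level as in the original:

example_model = ("ozgf.fs100.ch18-ld-norm.l1-sev_wc.18x4.g-fir.1x25x4-relu.4.f-"
                 "wc.4x1-lvl.1-dexp.1_tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.es20")
_queue_fits(batch=322, modelnames=[example_model], iterator=["TNC017a"])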
Example 8
    def list_all_recordings(self):
        print('listing all taggable sites...')
        df = nd.pd_query(
            "SELECT * FROM gSingleCell INNER JOIN sCellFile ON gSingleCell.cellid = sCellFile.cellid"
            + " WHERE sCellFile.RunClassid = 51 AND sCellFile.cellid LIKE %s",
            params=("TNC%", ))

        # clean up DF
        DF = pd.DataFrame()
        DF['cellid'] = df.cellid.iloc[:, 0]
        DF['siteid'] = df.siteid
        DF['recording'] = df.stimfile.apply(lambda x: x.split('.')[0])
        DF['parmfile'] = df.stimpath + df.stimfile  # full path to parameter file.
        DF['rawid'] = df.rawid.iloc[:, 1]

        self.recordings = DF.recording.unique()
        self.DF = DF
        print('done')
Example 9
def load_existing_pred(cellid=None,
                       siteid=None,
                       batch=None,
                       modelname_existing=None,
                       **kwargs):
    """
    designed to be called by xforms keyword loadpred 
    cellid/siteid - one or the other required
    batch - required
    default modelname_existing = "psth.fs4.pup-ld-st.pup-hrc-psthfr-aev_sdexp2.SxR_newtf.n.lr1e4.cont.et5.i50000"
    
    makes new signal 'pred0' from evaluated 'pred', returns in updated rec
    returns ctx-compatible dict {'rec': nems.Recording, 'input_name': 'pred0'}
    """
    if (batch is None):
        raise ValueError("must specify cellid/siteid and batch")

    if cellid is None:
        if siteid is None:
            raise ValueError("must specify cellid/siteid and batch")
        d = nd.pd_query(
            "SELECT batch,cellid FROM Batches WHERE batch=%s AND cellid like %s",
            (
                batch,
                siteid + "%",
            ))
        cellid = d['cellid'].values[0]
    elif type(cellid) is list:
        cellid = cellid[0]

    if modelname_existing is None:
        #modelname_existing = "psth.fs4.pup-ld-st.pup-hrc-psthfr-aev_sdexp2.SxR_newtf.n.lr1e4.cont"
        modelname_existing = "psth.fs4.pup-ld-st.pup-hrc-psthfr-aev_sdexp2.SxR_newtf.n.lr1e4.cont.et5.i50000"

    xf, ctx = xhelp.load_model_xform(cellid, batch, modelname_existing)
    for k in ctx['val'].signals.keys():
        if k not in ctx['rec'].signals.keys():
            ctx['rec'].signals[k] = ctx['val'].signals[k].copy()
    s = ctx['rec']['pred'].copy()
    s.name = 'pred0'
    ctx['rec'].add_signal(s)

    #return {'rec': ctx['rec'],'val': ctx['val'],'est': ctx['est']}
    return {'rec': ctx['rec'], 'input_name': 'pred0'}
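A minimal sketch of calling load_existing_pred directly (it is normally reached via the xforms 'loadpred' keyword); the siteid and batch are placeholder values:

ctx_update = load_existing_pred(siteid="TAR010c", batch=289)
rec = ctx_update['rec']                # recording now carries the copied 'pred0' signal
assert ctx_update['input_name'] == 'pred0'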
Example 10
    def reload_models(self):
        from nems_web.utilities.ModelFinder import ModelFinder

        t = self.comboBatch.currentText()
        session = Session()
        Analysis = Tables()['Analysis']

        modeltree = (session.query(Analysis.modeltree).filter(
            Analysis.name == self.current_analysis).first())
        modelextras = (session.query(Analysis.model_extras).filter(
            Analysis.name == self.current_analysis).first())
        # Pass modeltree string from Analysis to a ModelFinder constructor,
        # which will use a series of internal methods to convert the tree string
        # to a list of model names.
        # Then add any additional models specified in extraModels, and add
        # model_lists from extraAnalyses.
        if modeltree and modeltree[0]:
            #model_list = _get_models(modeltree[0])
            load, mod, fit = json.loads(modeltree[0])
            loader = ModelFinder(load).modellist
            model = ModelFinder(mod).modellist
            fitter = ModelFinder(fit).modellist
            combined = itertools.product(loader, model, fitter)
            model_list = ['_'.join(m) for m in combined]
            extraModels = [
                s.strip("\"\n").replace("\\n", "")
                for s in modelextras[0].split(',')
            ]
            model_list.extend(extraModels)
        else:
            model_list = []
        self.all_models = model_list

        #sql = f"SELECT DISTINCT modelname FROM Results WHERE batch={t}"
        #data = nd.pd_query(sql)
        #self.all_models = data['modelname'].to_list()
        sql = f"SELECT DISTINCT cellid FROM Batches WHERE batch={t}"
        data = nd.pd_query(sql)
        self.all_cellids = data['cellid'].to_list()
        self.lastbatch = t

        self.labelBatchName.setText(str(t))
Example 11
def second_fit_pop_models(batch, start_from=None, test_count=None):
    all_cellids = nd.get_batch_cells(batch, as_list=True)
    if batch == 322:
        sites = NAT4_A1_SITES
    else:
        sites = NAT4_PEG_SITES
    cellids = [
        c for c in all_cellids
        if np.any([c.startswith(s.split('.')[0]) for s in sites])
    ]

    modelnames = []
    for k, v in MODELGROUPS.items():
        if ('_single' not in k) and ('_exploration' not in k) and (k != 'LN'):
            modelnames.extend(v)
    iterator = cellids

    for siteid in iterator:
        for modelname in modelnames[start_from:test_count]:
            do_fit = True
            if not FORCE_RERUN:
                d = nd.pd_query(
                    "SELECT * FROM Results WHERE cellid like %s and modelname=%s and batch=%s",
                    params=(siteid + "%", modelname, batch))
                if len(d) > 0:
                    do_fit = False
                    print(f'Fit exists for {siteid} {batch} {modelname}')
            if do_fit:
                nd.enqueue_models(
                    celllist=[siteid],
                    batch=batch,
                    modellist=[modelname],
                    user="******",
                    #executable_path='/auto/users/jacob/bin/anaconda3/envs/jacob_nems/bin/python',
                    executable_path=
                    '/auto/users/svd/bin/miniconda3/envs/tf/bin/python',
                    script_path=
                    '/auto/users/jacob/bin/anaconda3/envs/jacob_nems/nems/scripts/fit_single.py'
                )

    return modelnames
Example 12
def get_training_files(animal, runclass, earliest_date, latest_date=None, pupil=False, min_trials=50):

    an_regex = "%" + animal + "%"

    if latest_date is None:
        latest_date = earliest_date

    # get list of all training parmfiles
    sql = "SELECT parmfile, resppath FROM gDataRaw WHERE runclass=%s and resppath like %s and training = 1 and bad=0 and trials>%s"
    if pupil:
        # require pupil processed
        sql = "SELECT parmfile, resppath FROM gDataRaw WHERE runclass=%s and resppath like %s and training = 1 and bad=0 and trials>%s and eyewin=2"
    parmfiles = nd.pd_query(sql, (runclass, an_regex, min_trials))

    try:
        parmfiles['date'] = [dt.datetime.strptime('-'.join(x.split('_')[1:-2]), '%Y-%m-%d') for x in parmfiles.parmfile]
        ed = dt.datetime.strptime(earliest_date, '%Y_%m_%d')
        ld = dt.datetime.strptime(latest_date, '%Y_%m_%d')
        parmfiles = parmfiles[(parmfiles.date >= ed) & (parmfiles.date <= ld)]
        return parmfiles
    
    except Exception:
        raise ValueError("No files found")
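A hypothetical call; the animal and runclass are borrowed from the behavior example further down, and the dates follow the '%Y_%m_%d' format the function expects:

files = get_training_files('Cordyceps', 'BVT', '2020_02_01',
                           latest_date='2020_02_28', pupil=True)
print(files[['parmfile', 'date']])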
Example 13
"""
Load spike waveform params for all SUs in batches 289, 294, and 331. Cluster based on these.
"""

import nems.db as nd
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

path = '/auto/users/hellerc/results/nat_pupil_ms/'
cellids_cache = path + 'celltypes.csv'

cellids = pd.DataFrame(pd.concat([nd.get_batch_cells(289), nd.get_batch_cells(294), nd.get_batch_cells(331)]).cellid)
iso_query = f"SELECT cellid, rawid, isolation from gSingleRaw WHERE cellid in {tuple([x for x in cellids.cellid])}"
isolation = nd.pd_query(iso_query)
cellids = pd.DataFrame(data=isolation[isolation.isolation>=95].cellid.unique(), columns=['cellid'])

# keep only SU
sw = [nd.get_gSingleCell_meta(cellid=c, fields='wft_spike_width') for c in cellids.cellid] 
cellids['spike_width'] = sw

# remove cellids that weren't sorted with KS (so don't have waveform stats)
cellids = cellids[cellids.spike_width!=-1]

# now save endslope and peak trough ratio
es = [nd.get_gSingleCell_meta(cellid=c, fields='wft_endslope') for c in cellids.cellid] 
pt = [nd.get_gSingleCell_meta(cellid=c, fields='wft_peak_trough_ratio') for c in cellids.cellid] 

cellids['end_slope'] = es
cellids['peak_trough'] = pt
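The KMeans import above is otherwise unused in this snippet; a minimal sketch of the clustering step implied by the docstring, where the feature set and n_clusters=2 are assumptions rather than the original analysis:

features = cellids[['spike_width', 'end_slope', 'peak_trough']].values
kmeans = KMeans(n_clusters=2, random_state=0).fit(features)
cellids['celltype'] = kmeans.labels_
cellids.to_csv(cellids_cache)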
Example 14
def initialize_with_prefit(modelspec, meta, area="A1", cellid=None, siteid=None, batch=322, pre_batch=None,
                           use_matched=False, use_simulated=False, use_full_model=False, 
                           prefit_type=None, freeze_early=True, IsReload=False, **ctx):
    """
    replace early layers of the model with fit parameters from a "standard" model ... for now that's a model with the
    same architecture fit to the NAT4 dataset
    
    for dnn single:
    initial model:
    modelname = "ozgf.fs100.ch18-ld-norm.l1-sev_wc.18x4.g-fir.1x25x4-relu.4.f-wc.4x1-lvl.1-dexp.1_tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.es20"
    
    use initial as pre-fit:
    modelname = "ozgf.fs100.ch18-ld-norm.l1-sev_wc.18x4.g-fir.1x25x4-relu.4.f-wc.4x1-lvl.1-dexp.1_prefit-tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4.es20"

    """
    if IsReload:
        return {}

    xi = find_module("weight_channels", modelspec, find_all_matches=True)
    if len(xi) == 0:
        raise ValueError("modelspec has no weight_channels layer to align")

    copy_layers = xi[-1]
    freeze_layer_count = xi[-1]
    batch = int(meta['batch'])
    modelname_parts = meta['modelname'].split("_")
    
    if use_simulated:
        guess = '.'.join(['SIM000a', modelname_parts[1]])

        # remove problematic characters
        guess = re.sub('[:]', '', guess)
        guess = re.sub('[,]', '', guess)
        if len(guess) > 100:
            # If modelname is too long, causes filesystem errors.
            guess = guess[:75] + '...' + str(hashlib.sha1(guess.encode('utf-8')).hexdigest()[:20])

        old_uri = f"/auto/data/nems_db/modelspecs/{guess}/modelspec.0000.json"
        log.info('loading saved modelspec from: ' + old_uri)

        new_ctx = load_phi(modelspec, prefit_uri=old_uri, copy_layers=copy_layers)
        
        return new_ctx

    elif prefit_type == 'init':
        # use full pop file - SVD work in progress. current best?
        load_string_pop = "ozgf.fs100.ch18.pop-loadpop-norm.l1-popev"
        fit_string_pop = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4"

        pre_part = load_string_pop
        if len(modelname_parts[2].split("-")) > 2:
            post_part = "-".join(modelname_parts[2].split("-")[1:-1])
        else:
            post_part = "-".join(modelname_parts[2].split("-")[1:])

        model_search = "_".join([pre_part, modelname_parts[1], post_part])

        if pre_batch is None:
            pre_batch = batch
        if pre_batch in [322, 334]:
            pre_cellid = 'ARM029a-07-6'
        elif pre_batch == 323:
            pre_cellid = 'ARM017a-01-9'
        else:
            raise ValueError(f"batch {pre_batch} prefit not implemented yet.")

        log.info(f"prefit cellid={pre_cellid}, skipping init_fit")
        copy_layers = len(modelspec)

    elif use_full_model:
        
        # use full pop file - SVD work in progress. current best?
        load_string_pop = "ozgf.fs100.ch18.pop-loadpop-norm.l1-popev"
        fit_string_pop = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4"

        if prefit_type == 'heldout':
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev"
        elif prefit_type == 'matched':
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev"
        elif prefit_type == 'matched_half':
            # 50% est data (matched cell excluded)
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev.k50"
        elif prefit_type == 'matched_quarter':
            # 25% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev.k25"
        elif prefit_type == 'matched_fifteen':
            # 15% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev.k15"
        elif prefit_type == 'matched_ten':
            # 10% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev.k10"
        elif prefit_type == 'heldout_half':
            # 50% est data, cell excluded (is this a useful condition?)
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev.k50"
        elif prefit_type == 'heldout_quarter':
            # 25% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev.k25"
        elif prefit_type == 'heldout_fifteen':
            # 15% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev.k15"
        elif prefit_type == 'heldout_ten':
            # 10% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev.k10"
        elif 'R.q.s' in modelname_parts[1]:
            pre_part = "ozgf.fs100.ch18-ld-norm.l1-sev"
        elif 'ch32' in modelname_parts[0]:
            pre_part = "ozgf.fs100.ch32.pop-loadpop-norm.l1-popev"
        elif 'ch64' in modelname_parts[0]:
            pre_part = "ozgf.fs100.ch64.pop-loadpop-norm.l1-popev"
        elif batch==333:
            # not pre-concatenated recording. different stim for each site, 
            # so fit each site separately (unless titan)
            pre_part = "ozgf.fs100.ch18-ld-norm.l1-sev"
        else:
            #load_string_pop = "ozgf.fs100.ch18.pop-loadpop-norm.l1-popev"
            pre_part = load_string_pop

        if prefit_type == 'titan':
            if batch==333:
                pre_part = load_string_pop
                post_part = "tfinit.n.mc50.lr1e3.et4.es20-newtf.n.mc100.lr1e4"
            else:
                post_part = "tfinit.n.mc25.lr1e3.es20-newtf.n.mc100.lr1e4.exa"
        else:
            #post_part = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.es20"
            post_part = fit_string_pop

        if modelname_parts[2].endswith(".l2:5") or modelname_parts[2].endswith(".l2:5-dstrf") or modelname_parts[2].endswith("ver5"):
            post_part += ".l2:5"
        elif modelname_parts[2].endswith(".l2:4") or modelname_parts[2].endswith(".l2:4-dstrf") or modelname_parts[2].endswith("ver4"):
            post_part += ".l2:4"
        elif modelname_parts[2].endswith(".l2:4.ver2"):
            post_part += ".l2:4.ver2"
        elif modelname_parts[2].endswith("ver2"):
            post_part += ".ver2"
        elif modelname_parts[2].endswith("ver1"):
            post_part += ".ver1"

        model_search = "_".join([pre_part, modelname_parts[1], post_part])
        if pre_batch is None:
            pre_batch = batch

        # this is a single-cell fit
        if type(cellid) is list:
            cellid = cellid[0]
        siteid = cellid.split("-")[0]
        allsiteids, allcellids = nd.get_batch_sites(batch, modelname_filter=model_search)
        allsiteids = [s.split(".")[0] for s in allsiteids]

        if (batch==323) and (pre_batch==322):
            matchfile=os.path.dirname(__file__) + "/projects/pop_model_scripts/snr_subset_map.csv"
            df = pd.read_csv(matchfile, index_col=0)
            pre_cellid = df.loc[df.PEG_cellid==cellid, 'A1_cellid'].values[0]
        elif (batch==322) and (pre_batch==323):
            matchfile=os.path.dirname(__file__) + "/projects/pop_model_scripts/snr_subset_map.csv"
            df = pd.read_csv(matchfile, index_col=0)
            pre_cellid = df.loc[df.A1_cellid==cellid, 'PEG_cellid'].values[0]

        elif siteid in allsiteids:
            # don't need to generalize, load from actual fit
            pre_cellid = cellid
        elif batch in [322, 334]:
            pre_cellid = 'ARM029a-07-6'
        elif pre_batch == 323:
            pre_cellid = 'ARM017a-01-9'
        else:
            raise ValueError(f"batch {batch} prefit not implemented yet.")
            
        log.info(f"prefit cellid={pre_cellid} prefit batch={pre_batch}")

    elif prefit_type == 'site':
        # exact same model, just fit for site, now being fit for single cell
        pre_parts = modelname_parts[0].split("-")
        post_parts = modelname_parts[2].split("-")
        model_search = modelname_parts[0] + "%%" + modelname_parts[1] + "%%" + "-".join(post_parts[1:])

        pre_cellid = cellid[0]
        pre_batch = batch
    elif prefit_type is not None:
        # this is a single-cell fit
        if type(cellid) is list:
            cellid = cellid[0]
            
        if prefit_type=='heldout':
            if siteid is None:
                siteid=cellid.split("-")[0]
            cellids, this_perf, alt_cellid, alt_perf = _matching_cells(batch=batch, siteid=siteid)

            pre_cellid = [c_alt for c,c_alt in zip(cellids,alt_cellid) if c==cellid][0]
            log.info(f"heldout init for {cellid} is {pre_cellid}")
        else:
            pre_cellid = cellid
            log.info(f"matched cellid prefit for {cellid}")
        if pre_batch is None:
            pre_batch = batch

        post_part = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4"
        if modelname_parts[2].endswith(".l2:5") or modelname_parts[2].endswith(".l2:5-dstrf"):
            post_part += ".l2:5"
        elif modelname_parts[2].endswith(".l2:4") or modelname_parts[2].endswith(".l2:4-dstrf"):
            post_part += ".l2:4"
        elif modelname_parts[2].endswith(".l2:4.ver2"):
            post_part += ".l2:4.ver2"
        elif modelname_parts[2].endswith("ver2"):
            post_part += ".ver2"
        modelname_parts[2] = post_part
        model_search="_".join(modelname_parts)

    elif modelname_parts[1].endswith(".1"):
        raise ValueError("deprecated prefit initialization?")
        # this is a single-cell fit
        if type(cellid) is list:
            cellid = cellid[0]
        
        if use_matched:
            # determine matched cell for this heldout cell
            if siteid is None:
                siteid=cellid.split("-")[0]
            cellids, this_perf, alt_cellid, alt_perf = _matching_cells(batch=batch, siteid=siteid)

            pre_cellid = [c_alt for c,c_alt in zip(cellids,alt_cellid) if c==cellid][0]
            log.info(f"matched cell for {cellid} is {pre_cellid}")
        else:
            pre_cellid = cellid[0]
            log.info(f"cellid prefit for {cellid}")
        if pre_batch is None:
            pre_batch = batch
        #postparts = modelname_parts[2].split("-")
        #postparts = [s for s in postparts if not(s.startswith("prefit"))]
        #modelname_parts[2]="-".join(postparts)
        modelname_parts[2] = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.es20"
        model_search="_".join(modelname_parts)

    else:
        pre_parts = modelname_parts[0].split("-")
        post_parts = modelname_parts[2].split("-")    
        post_part = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.ver2"
        model_search = pre_parts[0] + ".pop%%" + modelname_parts[1] + "%%" + post_part

        #ozgf.fs100.ch18.pop-loadpop-norm.l1-popev
        #wc.18x70.g-fir.1x15x70-relu.70.f-wc.70x80-fir.1x10x80-relu.80.f-wc.80x100-relu.100-wc.100xR-lvl.R-dexp.R
        #tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.ver2


        # hard-coded to use an A1 model!!!!
        if pre_batch == 322:
            pre_cellid = 'ARM029a-07-6'
        elif area == "A1":
            pre_cellid = 'ARM029a-07-6'
            pre_batch = 322
        else:
            raise ValueError(f"area {area} prefit not implemented")

    log.info(f"model_search: {model_search}")

    sql = f"SELECT * FROM Results WHERE batch={pre_batch} and cellid='{pre_cellid}' and modelname like '{model_search}'"
    #log.info(sql)
    
    d = nd.pd_query(sql)
    #old_uri = adjust_uri_prefix(d['modelpath'][0] + '/modelspec.0000.json')
    old_uri = adjust_uri_prefix(d['modelpath'][0])
    log.info(f"Importing parameters from {old_uri}")

    mspaths = [f"{old_uri}/modelspec.{i:04d}.json" for i in range(modelspec.cell_count)]
    print(mspaths)
    prefit_ctx = xforms.load_modelspecs([], uris=mspaths, IsReload=False)

    #_, prefit_ctx = xform_helper.load_model_xform(
    #    cellid=pre_cellid, batch=pre_batch,
    #    modelname=d['modelname'][0], eval_model=False)
    new_ctx = load_phi(modelspec, prefit_modelspec=prefit_ctx['modelspec'], copy_layers=copy_layers)
    if freeze_early:
        new_ctx['freeze_layers'] = list(np.arange(freeze_layer_count))
    if prefit_type == 'init':
        new_ctx['skip_init'] = True
    return new_ctx
Example 15
# set up axes
plt.figure(figsize=(14, 4))
example_ax = plt.subplot2grid((1, 6), (0, 0), colspan=4)
crd_ax = plt.subplot2grid((1, 6), (0, 4), colspan=1)
drx_ax = plt.subplot2grid((1, 6), (0, 5), colspan=1)
median = True

# ========================================= EXAMPLE SESSION ===========================================================
# define example session
filename = "Cordyceps_2020_02_20"
datestr = "2020-02-20"
window_length = 25

# get relevant parmfiles
sql = "SELECT parmfile, resppath, gDataRaw.id, cellid, gData.svalue, gData.value FROM gDataRaw INNER JOIN gData on (gDataRaw.id=gData.rawid and gData.name='Tar_Frequencies') WHERE runclass=%s and behavior='active' and parmfile like %s"
parmfiles1 = nd.pd_query(sql, ('BVT', filename + "%"))
parmfiles1['date'] = dt.datetime.strptime(datestr, '%Y-%m-%d')
parmfiles1 = parmfiles1.set_index('id')

sql = "SELECT parmfile, resppath, gDataRaw.id, cellid, gData.svalue FROM gDataRaw INNER JOIN gData on (gDataRaw.id=gData.rawid and gData.name='Behave_PumpDuration') WHERE runclass=%s and behavior='active' and parmfile like %s"
parmfiles = nd.pd_query(sql, ('BVT', filename + "%"))
parmfiles['date'] = dt.datetime.strptime(datestr, '%Y-%m-%d')
parmfiles = parmfiles.set_index('id')
parmfiles = parmfiles.rename(columns={'svalue': 'pumpdur'})

# now join on rawid
parmfiles = pd.concat([parmfiles['pumpdur'], parmfiles1], axis=1, join='outer')

parmfiles.loc[parmfiles.svalue.isnull(),
              'svalue'] = parmfiles.loc[parmfiles.svalue.isnull(),
                                        'value'].astype(str)
Example 16
from src.data.load import get_site_ids
from nems.db import pd_query
from warnings import warn

# get CPN sites
all_sites = get_site_ids(316)

region_map = dict()
for site in all_sites.keys():
    # pulls the region from celldb
    BF_querry = "select area from gCellMaster where siteid=%s"
    raw_area = pd_query(BF_querry, params=(site, )).iloc[0, 0]

    # Sanitizes region in case of missing values
    if raw_area is None:

        warn(f'site {site} has undefined region')
        print()
        continue
    else:
        area = raw_area.split(',')[0]
        if area == '':
            print(f'site {site} has undefined region')
            continue
        elif area not in ('A1', 'PEG'):
            print(f'site {site} has unrecognized region:{area}')
            continue

    region_map[site] = area
Example 17
def get_model_results_per_state_model(batch=307,
                                      state_list=None,
                                      loader="psth.fs20.pup-ld-",
                                      fitter="_jk.nf20-basic",
                                      basemodel="-ref-psthfr.s_sdexp.S"):
    """
    loader = "psth.fs20.pup-ld-"
    fitter = "_jk.nf20-basic"
    basemodel = "-ref-psthfr.s_sdexp.S"
    state_list = ['st.pup0.beh0','st.pup0.beh','st.pup.beh0','st.pup.beh']

    d=get_model_results_per_state_model(batch=307, state_list=state_list,
                                        loader=loader,fitter=fitter,
                                        basemodel=basemodel)

    state_list defaults to
       ['st.pup0.beh0','st.pup0.beh','st.pup.beh0','st.pup.beh']
    """

    if state_list is None:
        state_list = [
            'st.pup0.beh0', 'st.pup0.beh', 'st.pup.beh0', 'st.pup.beh'
        ]

    modelnames = [loader + s + basemodel + fitter for s in state_list]

    celldata = nd.get_batch_cells(batch=batch)
    cellids = celldata['cellid'].tolist()
    isolation = [
        nd.get_isolation(cellid=c, batch=batch).loc[0, 'min_isolation']
        for c in cellids
    ]

    if state_list[-1].endswith('fil') or state_list[-1].endswith('pas'):
        include_AP = True
    else:
        include_AP = False

    d = pd.DataFrame(columns=[
        'cellid', 'modelname', 'state_sig', 'state_chan', 'MI', 'isolation',
        'r', 'r_se', 'd', 'g', 'sp', 'state_chan_alt'
    ])

    new_sdexp = False
    for mod_i, m in enumerate(modelnames):
        print('Loading modelname: ', m)
        modelspecs = nems_db.params._get_modelspecs(cellids,
                                                    batch,
                                                    m,
                                                    multi='mean')

        for modelspec in modelspecs:
            meta = ms.get_modelspec_metadata(modelspec)
            phi = list(modelspec[0]['phi'].keys())
            c = meta['cellid']
            iso = isolation[cellids.index(c)]
            state_mod = meta['state_mod']
            state_mod_se = meta['se_state_mod']
            state_chans = meta['state_chans']
            if 'g' in phi:
                dc = modelspec[0]['phi']['d']
                gain = modelspec[0]['phi']['g']
            elif ('amplitude_g' in phi) & ('amplitude_d' in phi):
                new_sdexp = True
                dc = None
                gain = None
                g_amplitude = modelspec[0]['phi']['amplitude_g']
                g_base = modelspec[0]['phi']['base_g']
                g_kappa = modelspec[0]['phi']['kappa_g']
                g_offset = modelspec[0]['phi']['offset_g']
                d_amplitude = modelspec[0]['phi']['amplitude_d']
                d_base = modelspec[0]['phi']['base_d']
                d_kappa = modelspec[0]['phi']['kappa_d']
                d_offset = modelspec[0]['phi']['offset_d']

            gain_mod = None
            dc_mod = None
            if 'state_mod_gain' in meta.keys():
                gain_mod = meta['state_mod_gain']
                dc_mod = meta['state_mod_dc']

            if dc is not None:
                sp = modelspec[0]['phi'].get('sp', np.zeros(gain.shape))
                if dc.ndim > 1:
                    dc = dc[0, :]
                    gain = gain[0, :]
                    sp = sp[0, :]

            a_count = 0
            p_count = 0

            for j, sc in enumerate(state_chans):
                if gain is not None:
                    gain_val = gain[j]
                    dc_val = dc[j]
                    sp_val = sp[j]
                else:
                    gain_val = None
                    dc_val = None
                    sp_val = None
                r = {
                    'cellid': c,
                    'state_chan': sc,
                    'modelname': m,
                    'isolation': iso,
                    'state_sig': state_list[mod_i],
                    'g': gain_val,
                    'd': dc_val,
                    'sp': sp_val,
                    'MI': state_mod[j],
                    'r': meta['r_test'][0],
                    'r_se': meta['se_test'][0]
                }
                if new_sdexp:
                    r.update({
                        'g_amplitude': g_amplitude[0, j],
                        'g_base': g_base[0, j],
                        'g_kappa': g_kappa[0, j],
                        'g_offset': g_offset[0, j],
                        'd_amplitude': d_amplitude[0, j],
                        'd_base': d_base[0, j],
                        'd_kappa': d_kappa[0, j],
                        'd_offset': d_offset[0, j]
                    })
                if gain_mod is not None:
                    r.update({'gain_mod': gain_mod[j], 'dc_mod': dc_mod[j]})

                d = d.append(r, ignore_index=True)
                l = len(d) - 1

                if include_AP and sc.startswith("FILE_"):
                    siteid = c.split("-")[0]
                    fn = "%" + sc.replace("FILE_", "") + "%"
                    sql = "SELECT * FROM gDataRaw WHERE cellid=%s" +\
                       " AND parmfile like %s"
                    dcellfile = nd.pd_query(sql, (siteid, fn))
                    if dcellfile.loc[0]['behavior'] == 'active':
                        a_count += 1
                        d.loc[l,
                              'state_chan_alt'] = "ACTIVE_{}".format(a_count)
                    else:
                        p_count += 1
                        d.loc[l,
                              'state_chan_alt'] = "PASSIVE_{}".format(p_count)
                else:
                    d.loc[l, 'state_chan_alt'] = d.loc[l, 'state_chan']

    #d['r_unique'] = d['r'] - d['r0']
    #d['MI_unique'] = d['MI'] - d['MI0']

    return d
Example 18
def penetration_map(sites,
                    equal_aspect=False,
                    flip_X=False,
                    flatten=False,
                    flip_YZ=False,
                    landmarks=None):
    """
    Plots a 3d map of the list of specified sites, displaying the best frequency as color, and the brain region as
    maker type (NA: circle, A1: triangle, PEG: square).
    The site location, brain area and best frequency are extracted from celldb, specifically from the penetration (for
    coordinates and rotations) and the penetration (for area and best frequency) sites. If no coordinates are found the
    site is ignored. no BF is displayed as an empty marker.
    The values in the plot increase in direction Posterior -> Anterior, Medial -> Lateral and Ventral -> Dorsal. The
    antero posterior (X) axis can be flipped with the according parameter.
    :param sites: list of str specifying sites, with the format "ABC001a"
    :param equal_aspect: boolean. whether to constrain the data to a cubic/square space i.e. equal dimensions in XYZ/XY
    :flip_X: Boolean. Flips the direction labels for the antero-posterior (X) axis. The default is A > P .
    Y. Lateral > Medial, Z. Dorsal > Ventral.
    :flatten: Boolean. PCA 2d projection. Work in progress.
    :flip_YZ: Boolean. Flips the direction and labels of the YZ principal component when flattening.
    :landmarks: dict of vectors, where the key specifies the landmark name, and the vector has the values
    [x0, y0, z0, x, y, z, tilt, rot]. If the landmark name is 'OccCrest' or 'MidLine' uses the AP and ML values as zeros
    respectively.
    :return: matplotlib figure
    """
    area_marker = {'NA': 'o', 'A1': '^', 'PEG': 's'}

    coordinates = list()
    best_frequencies = list()
    areas = list()
    good_sites = list()

    # get values from cell db and transforms into coordinates
    for pp, site in enumerate(sites):
        # gets the penetrations MT coordinates
        penetration = site[0:6]
        coord_querry = "select ecoordinates from gPenetration where penname=%s"
        raw_coords = db.pd_query(coord_querry, params=(penetration, )).iloc[0,
                                                                            0]
        all_coords = np.asarray(raw_coords.split(',')).astype(float).squeeze()

        # Dimensions X:Antero-Posterior Y:Medio-Lateral Z:Dorso-Ventral
        MT_0 = all_coords[0:3]
        MT = all_coords[3:6]
        tilt = radians(all_coords[6])
        rotation = radians(all_coords[7])

        # rejects sites with no coordinates
        no_ref = np.all(MT_0 == 0)
        no_val = np.all(MT == 0)
        if no_ref or no_val:
            print(
                f'skipping penetration {penetration}, No coordinates specified'
            )
            continue

        # defines tilt and rotation matrices, tilt around X axis, rotation around Z axis
        tilt_mat = np.asarray([[1, 0, 0], [0, cos(tilt),
                                           sin(tilt)],
                               [0, -sin(tilt), cos(tilt)]])
        rot_mat = np.asarray([[cos(rotation), sin(rotation), 0],
                              [-sin(rotation),
                               cos(rotation), 0], [0, 0, 1]])

        # calculates the relative MT coordinates and rotates.
        MT_rel_rot = rot_mat @ tilt_mat @ (MT - MT_0)

        # get the first value of BF from cellDB TODO this is a temporary kludge; in the future, with all BFs set, a more elegant approach is required
        BF_querry = "select bf, area from gCellMaster where siteid=%s"
        try:
            raw_BF, raw_area = db.pd_query(BF_querry,
                                           params=(site, )).iloc[0, :]
        except Exception:
            raw_BF = None
            raw_area = None
        # Sanitize best frequency in case of missing values
        if raw_BF is None:
            print(f'site {site} has undefined best frequency')
            BF = 0
        else:
            BF = int(raw_BF.split(',')[0])
            if BF == 0:
                print(f'site {site} has undefined best frequency')

        # Sanitizes region in case of missing values
        if raw_area is None:
            print(f'site {site} has undefined region')
            area = 'NA'
        else:
            area = raw_area.split(',')[0]
            if area == '':
                print(f'site {site} has undefined region')
                area = 'NA'
            elif area not in ('A1', 'PEG'):
                print(f'site {site} has unrecognized region: {area}')
                area = 'NA'

        coordinates.append(MT_rel_rot)
        best_frequencies.append(BF)
        areas.append(area)
        good_sites.append(site)

    # adds manual landmarks specified in dictionary
    if landmarks is not None:
        X0 = []
        Y0 = []
        for landname, all_coords in landmarks.items():
            all_coords = np.asarray(all_coords)
            MT_0 = all_coords[0:3]
            MT = all_coords[3:6]
            tilt = radians(all_coords[6])
            rotation = radians(all_coords[7])

            # defines tilt and rotation matrices, tilt around X axis, rotation around Z axis
            tilt_mat = np.asarray([[1, 0, 0], [0, cos(tilt),
                                               sin(tilt)],
                                   [0, -sin(tilt), cos(tilt)]])
            rot_mat = np.asarray([[cos(rotation),
                                   sin(rotation), 0],
                                  [-sin(rotation),
                                   cos(rotation), 0], [0, 0, 1]])

            # calculates the relative MT coordinates and rotates.
            MT_rel_rot = rot_mat @ tilt_mat @ (MT - MT_0)

            coordinates.append(MT_rel_rot)
            best_frequencies.append(0)
            areas.append('NA')
            # pads with spaces and keeps 3 letters, for consistent naming with sites
            good_sites.append(f'   {landname[0:3]}')

            #saves values as zero reference if correct landname
            if landname == 'OccCrest': X0.append(MT_rel_rot[0])
            if landname == 'MidLine': Y0.append(MT_rel_rot[1])

    coordinates = np.stack(coordinates, axis=1)
    best_frequencies = np.asarray(best_frequencies)
    areas = np.asarray(areas)
    good_sites = np.asarray(good_sites)

    # centers data and transforms cm to mm
    center = np.mean(coordinates, axis=1)

    # uses landmarks as zero values if any
    if landmarks is not None:
        if X0: center[0] = X0[0]
        if Y0: center[1] = Y0[0]

    coordinates = coordinates - center[:, None]
    coordinates = coordinates * 10

    # defines BF colormap range if valid best frequencies are available.
    vmax = best_frequencies.max() if best_frequencies.max() > 0 else 32000
    vmin = best_frequencies[
        best_frequencies != 0].min() if best_frequencies.min() > 0 else 100

    if flatten is False:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        if equal_aspect:
            X, Y, Z = coordinates
            # Create cubic bounding box to simulate equal aspect ratio
            max_range = np.array(
                [X.max() - X.min(),
                 Y.max() - Y.min(),
                 Z.max() - Z.min()]).max()
            Xb = 0.5 * max_range * np.mgrid[-1:2:2, -1:2:2, -1:2:2][0].flatten(
            ) + 0.5 * (X.max() + X.min())
            Yb = 0.5 * max_range * np.mgrid[-1:2:2, -1:2:2, -1:2:2][1].flatten(
            ) + 0.5 * (Y.max() + Y.min())
            Zb = 0.5 * max_range * np.mgrid[-1:2:2, -1:2:2, -1:2:2][2].flatten(
            ) + 0.5 * (Z.max() + Z.min())
            for xb, yb, zb in zip(Xb, Yb, Zb):
                ax.plot([xb], [yb], [zb], 'w')

        for area in set(areas):
            coord_subset = coordinates[:, areas == area]
            BF_subset = best_frequencies[areas == area]
            site_subset = good_sites[areas == area]

            X, Y, Z = coord_subset
            p = ax.scatter(X,
                           Y,
                           Z,
                           s=100,
                           marker=area_marker[area],
                           edgecolor='black',
                           c=BF_subset,
                           cmap='inferno',
                           norm=colors.LogNorm(vmin=vmin, vmax=vmax))

            for coord, site in zip(coord_subset.T, site_subset):
                x, y, z = coord
                ax.text(x, y, z, site[3:6])

        # formats axis
        ax.set_xlabel('anterior posterior (mm)')
        ax.set_ylabel('Medial Lateral (mm)')
        ax.set_zlabel('Dorsal ventral (mm)')

        fig.canvas.draw()
        x_tick_loc = ax.get_xticks().tolist()
        xlab = [f'{x:.1f}' for x in x_tick_loc]

        if flip_X:
            xlab[0] = 'A'
            xlab[-1] = 'P'
        else:
            xlab[0] = 'P'
            xlab[-1] = 'A'

        _ = ax.xaxis.set_major_locator(mticker.FixedLocator(x_tick_loc))
        _ = ax.set_xticklabels(xlab)

        y_tick_loc = ax.get_yticks().tolist()
        ylab = [f'{x:.1f}' for x in y_tick_loc]
        ylab[0] = 'M'
        ylab[-1] = 'L'
        _ = ax.yaxis.set_major_locator(mticker.FixedLocator(y_tick_loc))
        _ = ax.set_yticklabels(ylab)

        z_tick_loc = ax.get_zticks().tolist()
        zlab = [f'{x:.1f}' for x in z_tick_loc]
        zlab[0] = 'V'
        zlab[-1] = 'D'
        _ = ax.zaxis.set_major_locator(mticker.FixedLocator(z_tick_loc))
        _ = ax.set_zticklabels(zlab)

    elif flatten is True:
        # flattens doing a PCA over the Y and Z dimensions, i.e. medio-lateral and dorso-ventral.
        # this keeps the anteroposterior orientations to help locate the flattened projection
        pc1 = PCA().fit_transform(coordinates[1:, :].T)[:, 0]

        # uses midline zero as pc1 zero
        if landmarks is not None and 'MidLine' in landmarks.keys():
            pc1 = pc1 - pc1[np.argwhere(coordinates[1, :] == 0)].squeeze()

        flat_coords = np.stack((coordinates[0, :], pc1), axis=0)

        # Flips data on axes since 2d plots cannot be rotated interactively
        fx = -1 if flip_X is True else 1
        fyz = -1 if flip_YZ is True else 1
        flat_coords = flat_coords * np.array([[fx], [fyz]])

        fig, ax = plt.subplots()

        for area in set(areas):
            flat_coords_subset = flat_coords[:, areas == area]
            BF_subset = best_frequencies[areas == area]
            site_subset = good_sites[areas == area]

            X, Y = flat_coords_subset
            p = ax.scatter(X,
                           Y,
                           s=100,
                           marker=area_marker[area],
                           edgecolor='black',
                           c=BF_subset,
                           cmap='inferno',
                           norm=colors.LogNorm(vmin=vmin, vmax=vmax))

            for coord, site in zip(flat_coords_subset.T, site_subset):
                x, y = coord
                ax.text(x, y, site[3:6])

        # formats axis
        if equal_aspect:
            ax.axis('equal')
        ax.set_xlabel('anterior posterior (mm)')
        ax.set_ylabel('1PC_YZ (mm)')

    cbar = fig.colorbar(p)
    cbar.ax.set_ylabel('BF (Hz)', rotation=-90, va="top")

    return fig, coordinates
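Hypothetical usage of penetration_map; the site names follow the "ABC001a" format described in the docstring and are placeholders:

fig, coords = penetration_map(['TNC017a', 'TNC018a', 'TNC020a'],
                              equal_aspect=True, flip_X=False, flatten=False)
fig.savefig('penetration_map.pdf')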
Example 19
def plot_collapsed_ref_tar(animal, site, cellids=None):
    site += "%"

    sql = "SELECT DISTINCT cellid, rawid, respfile FROM sCellFile WHERE cellid like %s AND runclassid=%s"
    d = nd.pd_query(sql, params=(
        site,
        42,
    ))

    mfile = []
    for f in np.unique(d.respfile):
        f_ = f.split('.')[0]
        mfile.append('/auto/data/daq/{0}/{1}/{2}'.format(
            animal, site[:-2], f_))

    if cellids is None:
        cellid = np.unique(d.cellid).tolist()
    else:
        cellid = cellids
    options = {
        "siteid": site[:-1],
        'cellid': cellids,
        "mfilename": mfile,
        'stim': False,
        'pupil': True,
        'rasterfs': 1000
    }

    uri = nb.baphy_load_recording_uri(**options)
    rec = Recording.load(uri)
    all_pupil = rec['pupil']._data
    ncols = len(mfile)
    if cellids is None:
        cellids = rec['resp'].chans

    for c in cellids:
        f, ax = plt.subplots(1, ncols, sharey=True, figsize=(12, 5))
        ref_base = 0
        for i, mf in enumerate(mfile):
            fn = mf.split('/')[-1]
            ep_mask = [ep for ep in rec.epochs.name if fn in ep]
            R = rec.copy()
            R['resp'] = R['resp'].rasterize(fs=20)
            R['resp'].fs = 20
            R = R.and_mask(ep_mask).apply_mask(reset_epochs=True)
            if '_a_' in fn:
                R = R.and_mask(['HIT_TRIAL']).apply_mask(reset_epochs=True)

            resp = R['resp'].extract_channels([c])

            tar_reps = resp.extract_epoch('TARGET').shape[0]
            tar_m = np.nanmean(resp.extract_epoch('TARGET'),
                               0).squeeze() * R['resp'].fs
            tar_sem = R['resp'].fs * np.nanstd(resp.extract_epoch('TARGET'),
                                               0).squeeze() / np.sqrt(tar_reps)

            ref_reps = resp.extract_epoch('REFERENCE').shape[0]
            ref_m = np.nanmean(resp.extract_epoch('REFERENCE'),
                               0).squeeze() * R['resp'].fs
            ref_sem = R['resp'].fs * np.nanstd(resp.extract_epoch('REFERENCE'),
                                               0).squeeze() / np.sqrt(ref_reps)

            # plot psth's
            time = np.linspace(0, len(tar_m) / R['resp'].fs, len(tar_m))
            ax[i].plot(time, tar_m, color='red', lw=2)
            ax[i].fill_between(time,
                               tar_m + tar_sem,
                               tar_m - tar_sem,
                               color='coral')
            time = np.linspace(0, len(ref_m) / R['resp'].fs, len(ref_m))
            ax[i].plot(time, ref_m, color='blue', lw=2)
            ax[i].fill_between(time,
                               ref_m + ref_sem,
                               ref_m - ref_sem,
                               color='lightblue')

            # set title
            ax[i].set_title(fn, fontsize=8)
            # set labels
            ax[i].set_xlabel('Time (s)')
            ax[i].set_ylabel('Spk / sec')

            # get raster plot baseline
            base = np.max(np.concatenate((tar_m + tar_sem, ref_m + ref_sem)))
            if base > ref_base:
                ref_base = base

        for i, mf in enumerate(mfile):
            # plot the rasters
            fn = mf.split('/')[-1]
            ep_mask = [ep for ep in rec.epochs.name if fn in ep]
            rast_rec = rec.and_mask(ep_mask).apply_mask(reset_epochs=True)
            if '_a_' in fn:
                rast_rec = rast_rec.and_mask(['HIT_TRIAL'
                                              ]).apply_mask(reset_epochs=True)
            rast = rast_rec['resp'].extract_channels([c])

            ref_times = np.where(rast.extract_epoch('REFERENCE').squeeze())
            base = ref_base
            ref_pupil = np.nanmean(
                rast_rec['pupil'].extract_epoch('REFERENCE'), -1)
            xoffset = rast.extract_epoch(
                'TARGET').shape[-1] / rec['resp'].fs + 0.01
            ax[i].plot(ref_pupil / np.max(all_pupil) + xoffset,
                       np.linspace(base, int(base * 2), len(ref_pupil)),
                       color='k')
            ax[i].axvline(xoffset + np.median(all_pupil / np.max(all_pupil)),
                          linestyle='--',
                          color='lightgray')
            #import pdb; pdb.set_trace()
            if ref_times[0].size != 0:
                max_rep = ref_pupil.shape[0] - 1
                ref_locs = ref_times[0] * (base / max_rep)
                ref_locs = ref_locs + base
                ax[i].plot(ref_times[1] / rec['resp'].fs,
                           ref_locs,
                           '|',
                           color='blue',
                           markersize=1)

            tar_times = np.where(rast.extract_epoch('TARGET').squeeze())
            tar_pupil = np.nanmean(rast_rec['pupil'].extract_epoch('TARGET'),
                                   -1)
            tar_base = np.max(ref_locs) + 1
            ax[i].plot(tar_pupil / np.max(all_pupil) + xoffset,
                       np.linspace(tar_base, tar_base + base, len(tar_pupil)),
                       color='k')
            if tar_times[0].size != 0:
                max_rep = tar_pupil.shape[0] - 1
                tar_locs = tar_times[0] * (base / max_rep)
                tar_locs = tar_locs + tar_base
                ax[i].plot(tar_times[1] / rec['resp'].fs,
                           tar_locs,
                           '|',
                           color='red',
                           markersize=1)

            # set ylim
            #ax[i].set_ylim((0, top))

            # set plot aspect
            #asp = np.diff(ax[i].get_xlim())[0] / np.diff(ax[i].get_ylim())[0]
            #ax[i].set_aspect(asp / 2)

        f.suptitle(c, fontsize=8)
        f.tight_layout()
Example 20
    120201, 120207, 120209, 120211, 120214, 120234, 120234, 120254, 120256,
    120258, 120260, 120272, 120273, 120274, 120275, 120293, 120283, 120285,
    120286, 120289, 120290, 120293, 120310, 120311, 120312, 120313, 120314,
    120316, 120317, 120435, 120436, 120437
]

# ================================= batch 307 ====================================
perfile_df = pd.read_csv(os.path.join(fpath, str(307), 'd_pup_fil_sdexp.csv'),
                         index_col=0)
df_307 = pd.DataFrame()
cells_307 = nd.get_batch_cells(307).cellid
for cellid in cells_307:
    _, rawid = nd.get_stable_batch_cells(batch=307, cellid=cellid)
    sql = "SELECT value, svalue, rawid from gData where name='Trial_TargetIdxFreq' and rawid in {}".format(
        tuple(rawid))
    d = nd.pd_query(sql, params=())
    sql = "SELECT value, svalue, rawid from gData where name='Trial_RelativeTarRefdB' and rawid in {}".format(
        tuple(rawid))
    d2 = nd.pd_query(sql, params=())
    sql = "SELECT behavior, id from gDataRaw where id in {0}".format(
        tuple(rawid))
    da = nd.pd_query(sql)

    d = d[d.rawid.isin(
        [r for r in da.id if da[da.id == r]['behavior'].values == 'active'])]
    d2 = d2[d2.rawid.isin(
        [r for r in da.id if da[da.id == r]['behavior'].values == 'active'])]
    d2.columns = [c + '_rel' for c in d2.columns]
    d = pd.concat([d, d2], axis=1)

    pf_labels = np.unique([
Example 21
#cellid="BRT038b-30-1"
#modelname="ozgf.fs100.ch18-ld-sev_dlog-wc.18x3-fir.1x15x3-relu.3-wc.3x1-lvl.1-dexp.1_tf.n.rb5"
#modelname="ozgf.fs100.ch18-ld-sev_dlog-wc.18x16.g-fir.4x15x4-relu.4-wc.4x1-lvl.1-dexp.1_tf.n.rb5"
#modelname='ozgf.fs100.ch18-ld-sev_dlog-wc.18x12-fir.4x15x3-relu.3-wc.3x1-lvl.1-dexp.1_tf.n.rb5'

cellid="BRT033b-12-3"
batch=308
modelname="ozgf.fs100.ch18-ld-sev_dlog-wc.18x16-fir.4x15x4-relu.4-wc.4x1-lvl.1-dexp.1_tf.n.rb5"
# modelname="ozgf.fs100.ch18-ld-sev_dlog-wc.18x2.g-fir.2x15-relu.1-lvl.1-dexp.1_tf.n.rb5"

xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
ex = gui.browse_xform_fit(ctx, xfspec)

########################################################
########################################################
#########################################################
########################################################


import nems.db as nd

batch=308

performance_data = nd.pd_query("SELECT modelname,cellid,r_test,r_floor FROM Results WHERE batch={}".format(batch))
# mark fits whose r_test exceeds twice the noise floor as significant
performance_data['significant'] = performance_data['r_test'] > 2 * performance_data['r_floor']
performance_data['zero_test'] = (performance_data['r_test'] == 0)

performance_data.groupby(['modelname'])[['r_test', 'significant', 'zero_test']].mean()
performance_data.groupby(['modelname'])['zero_test'].mean()
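
# Hedged example (not in the original): visualize the per-model summary computed above
# as a horizontal bar chart of mean r_test.
import matplotlib.pyplot as plt

summary = performance_data.groupby('modelname')[['r_test', 'significant']].mean()

fig, ax = plt.subplots(figsize=(8, 4))
summary['r_test'].plot.barh(ax=ax, color='steelblue')
ax.set_xlabel('mean r_test')
ax.set_title('batch {}: mean prediction accuracy per model'.format(batch))
fig.tight_layout()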

Example n. 22
"""
Summary plots of r_test, gain vs. DC, and MI.
"""

import nems.db as nd

import matplotlib.pyplot as plt
import pandas as pd

batches = [289, 294, 323]
modelnames = [
    'ns.fs4.pup-ld-st.pup-hrc-psthfr_sdexp.SxR.bound_jk.nf10-basic',
    'ns.fs4.pup-ld-st.pup0-hrc-psthfr_sdexp.SxR.bound_jk.nf10-basic',
    'ns.fs4.pup.voc-ld-st.pup-hrc-psthfr_sdexp.SxR.bound_jk.nf10-basic',
    'ns.fs4.pup.voc-ld-st.pup0-hrc-psthfr_sdexp.SxR.bound_jk.nf10-basic',
]
sql = "SELECT r_test, se_test, cellid, modelname, batch from Results WHERE modelname in {0} and batch in {1}".format(tuple(modelnames), tuple(batches))

results = nd.pd_query(sql)
results['state_mod'] = ['st.pup' if 'st.pup0' not in s else 'st.pup0' for s in results['modelname']]

results = results[results.batch==323]

r = results.pivot(columns='state_mod', index='cellid')
rdiff = r.loc[:, pd.IndexSlice['r_test', 'st.pup']] - r.loc[:, pd.IndexSlice['r_test', 'st.pup0']]
se = r.loc[:, pd.IndexSlice['se_test', 'st.pup']] + r.loc[:, pd.IndexSlice['se_test', 'st.pup0']]
sig_mask = rdiff > se
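
# hedged addition (not in the original): report how many cells pass the significance test
print(f"{int(sig_mask.sum())} / {len(sig_mask)} cells with a significant pupil effect in batch 323")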

f, ax = plt.subplots(1, 1, figsize=(6, 6))

ax.scatter(r.loc[:, pd.IndexSlice['r_test', 'st.pup0']] ** 2,
           r.loc[:, pd.IndexSlice['r_test', 'st.pup']] ** 2, 
           marker='o', edgecolor='white', color='grey')
ax.scatter(r.loc[sig_mask, pd.IndexSlice['r_test', 'st.pup0']] ** 2,
           r.loc[sig_mask, pd.IndexSlice['r_test', 'st.pup']] ** 2, 
Example n. 23
import numpy as np
from scipy.signal import convolve2d, butter, sosfilt

from nems import db
from nems.utils import smooth
from nems_lbhb.xform_wrappers import generate_recording_uri
from nems_lbhb.baphy_experiment import BAPHYExperiment
from nems_lbhb.baphy_io import load_continuous_openephys
from nems_lbhb.plots import plot_waveforms_64D

USE_DB = False

if USE_DB:
    expt_name = "TNC020a11_p_BNB"
    expt_name = "TNC017a10_p_BNB"
    dparm = db.pd_query(
        f"SELECT * FROM gDataRaw where parmfile like '{expt_name}%'")
    parmfile = dparm.resppath[0] + dparm.parmfile[0]
else:
    #hard-code path to parmfile
    #parmfile = "/auto/data/daq/Teonancatl/TNC018/TNC018a16_p_BNB.m"
    #parmfile = "/auto/data/daq/Teonancatl/TNC020/TNC020a11_p_BNB.m"
    parmfile = "/auto/data/daq/Teonancatl/TNC017/TNC017a03_p_BNB.m"
    parmfile = "/auto/data/daq/Teonancatl/TNC017/TNC017a10_p_BNB.m"
    parmfile = "/auto/data/daq/Teonancatl/TNC016/TNC016a03_p_BNB.m"
    parmfile = "/auto/data/daq/Teonancatl/TNC018/TNC018a03_p_BNB.m"
    parmfile = "/auto/data/daq/Tartufo/TAR010/TAR010a03_p_BNB.m"
    parmfile = "/auto/data/daq/Teonancatl/TNC006/TNC006a03_p_BNB.m"
    parmfile = "/auto/data/daq/Teonancatl/TNC006/TNC006a19_p_BNB.m"

## load the recording
parmfile = "/auto/data/daq/Tartufo/TAR010/TAR010a03_p_BNB.m"
Example n. 24
def calc_psth_metrics(batch, cellid, parmfile=None, paths=None):
    start_win_offset = 0  # Time (in sec) to offset the start of the window used to calculate threshold, excitatory percentage, and inhibitory percentage
    if parmfile:
        manager = BAPHYExperiment(parmfile)
    else:
        manager = BAPHYExperiment(cellid=cellid, batch=batch)

    options = ohel.get_load_options(
        batch)  #gets options that will include gtgram if batch=339
    rec = manager.get_recording(**options)

    area_df = db.pd_query(
        f"SELECT DISTINCT area FROM sCellFile where cellid like '{manager.siteid}%%'"
    )
    area = area_df.area.iloc[0]

    if rec['resp'].epochs[rec['resp'].epochs['name'] ==
                          'PASSIVE_EXPERIMENT'].shape[0] >= 2:
        rec = ohel.remove_olp_test(rec)

    rec['resp'] = rec['resp'].extract_channels([cellid])
    resp = copy.copy(rec['resp'].rasterize())
    rec['resp'].fs = 100

    norm_spont, SR, STD = ohel.remove_spont_rate_std(resp)
    params = ohel.get_expt_params(resp, manager, cellid)

    epcs = rec['resp'].epochs[rec['resp'].epochs['name'] ==
                              'PreStimSilence'].copy()
    ep2 = rec['resp'].epochs[rec['resp'].epochs['name'] ==
                             'PostStimSilence'].iloc[0].copy()
    params['prestim'], params['poststim'] = epcs.iloc[0][
        'end'], ep2['end'] - ep2['start']
    params['lenstim'] = ep2['end']

    stim_epochs = ep.epoch_names_matching(resp.epochs, 'STIM_')

    if paths and cellid[:3] == 'TBR':
        print(f"Deprecated, run on {cellid} though...")
        stim_epochs, rec, resp = ohel.path_tabor_get_epochs(
            stim_epochs, rec, resp, params)

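    # Build a (max reps, n stimuli, time bins) array of single-trial responses,
    # NaN-padded where a stimulus has fewer repetitions than the maximum.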
    epoch_repetitions = [resp.count_epoch(cc) for cc in stim_epochs]
    full_resp = np.empty((max(epoch_repetitions), len(stim_epochs),
                          (int(params['lenstim']) * rec['resp'].fs)))
    full_resp[:] = np.nan
    for cnt, epo in enumerate(stim_epochs):
        resps_list = resp.extract_epoch(epo)
        full_resp[:resps_list.shape[0], cnt, :] = resps_list[:, 0, :]

    #Calculate a few metrics
    corcoef = ohel.calc_base_reliability(full_resp)
    avg_resp = ohel.calc_average_response(full_resp, params)
    snr = compute_snr(resp)

    #Grab and label epochs that have two sounds in them (no null)
    presil, postsil = int(params['prestim'] * rec['resp'].fs), int(
        params['poststim'] * rec['resp'].fs)
    twostims = resp.epochs[resp.epochs['name'].str.count('-0-1') == 2].copy()
    ep_twostim = twostims.name.unique().tolist()
    ep_twostim.sort()

    ep_names = resp.epochs[resp.epochs['name'].str.contains('STIM_')].copy()
    ep_names = ep_names.name.unique().tolist()
    ep_types = list(map(ohel.label_ep_type, ep_names))
    ep_df = pd.DataFrame({'name': ep_names, 'type': ep_types})

    cell_df = []
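    # For each two-sound (BG/FG) epoch, compare the dual-sound response to each
    # single-sound response and to their linear sum (the pref and supp metrics below).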
    for cnt, stimmy in enumerate(ep_twostim):
        kind = ohel.label_pair_type(stimmy)
        seps = (stimmy.split('_')[1], stimmy.split('_')[2])
        BG, FG = seps[0].split('-')[0][2:], seps[1].split('-')[0][2:]

        Aepo, Bepo = 'STIM_' + seps[0] + '_null', 'STIM_null_' + seps[1]

        rAB = resp.extract_epoch(stimmy)
        rA, rB = resp.extract_epoch(Aepo), resp.extract_epoch(Bepo)

        fn = lambda x: np.atleast_2d(sp.smooth(x.squeeze(), 3, 2) - SR)
        rAsm = np.squeeze(np.apply_along_axis(fn, 2, rA))
        rBsm = np.squeeze(np.apply_along_axis(fn, 2, rB))
        rABsm = np.squeeze(np.apply_along_axis(fn, 2, rAB))

        rA_st, rB_st = rAsm[:, presil:-postsil], rBsm[:, presil:-postsil]
        rAB_st = rABsm[:, presil:-postsil]

        rAm, rBm = np.nanmean(rAsm, axis=0), np.nanmean(rBsm, axis=0)
        rABm = np.nanmean(rABsm, axis=0)

        AcorAB = np.corrcoef(
            rAm, rABm)[0, 1]  # Corr between resp to A and resp to dual
        BcorAB = np.corrcoef(
            rBm, rABm)[0, 1]  # Corr between resp to B and resp to dual

        A_FR, B_FR, AB_FR = np.nanmean(rA_st), np.nanmean(rB_st), np.nanmean(
            rAB_st)

        min_rep = np.min(
            (rA.shape[0],
             rB.shape[0]))  #only will do something if SoundRepeats==Yes
        lin_resp = np.nanmean(rAsm[:min_rep, :] + rBsm[:min_rep, :], axis=0)
        supp = np.nanmean(lin_resp - AB_FR)

        AcorLin = np.corrcoef(
            rAm, lin_resp)[0, 1]  # Corr between resp to A and resp to lin
        BcorLin = np.corrcoef(
            rBm, lin_resp)[0, 1]  # Corr between resp to B and resp to lin

        Apref, Bpref = AcorAB - AcorLin, BcorAB - BcorLin
        pref = Apref - Bpref

        # if params['Binaural'] == 'Yes':
        #     dA, dB = ohel.get_binaural_adjacent_epochs(stimmy)
        #
        #     rdA, rdB = resp.extract_epoch(dA), resp.extract_epoch(dB)
        #     rdAm = np.nanmean(np.squeeze(np.apply_along_axis(fn, 2, rdA))[:, presil:-postsil], axis=0)
        #     rdBm = np.nanmean(np.squeeze(np.apply_along_axis(fn, 2, rdB))[:, presil:-postsil], axis=0)
        #
        #     ABcordA = np.corrcoef(rABm, rdAm)[0, 1]  # Corr between resp to AB and resp to BG swap
        #     ABcordB = np.corrcoef(rABm, rdBm)[0, 1]  # Corr between resp to AB and resp to FG swap

        cell_df.append({
            'epoch': stimmy,
            'kind': kind,
            'BG': BG,
            'FG': FG,
            'AcorAB': AcorAB,
            'BcorAB': BcorAB,
            'AcorLin': AcorLin,
            'BcorLin': BcorLin,
            'Apref': Apref,
            'Bpref': Bpref,
            'pref': pref,
            'combo_FR': AB_FR,
            'bg_FR': A_FR,
            'fg_FR': B_FR,
            'supp': supp
        })

    cell_df = pd.DataFrame(cell_df)
    cell_df['SR'], cell_df['STD'] = SR, STD
    # cell_df['corcoef'], cell_df['avg_resp'], cell_df['snr'] = corcoef, avg_resp, snr
    cell_df.insert(loc=0, column='area', value=area)

    return cell_df
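
    # NOTE: the function returns above, so everything below this point is unreachable
    # legacy code, apparently kept for reference from an earlier version of this analysis.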

    # COMPUTE ALL FOLLOWING metrics using smoothed driven rate
    # est, val = rec.split_using_epoch_occurrence_counts(rec,epoch_regex='^STIM_')
    val = rec.copy()
    val['resp'] = val['resp'].rasterize()
    val['stim'] = val['stim'].rasterize()
    val = preproc.average_away_epoch_occurrences(val, epoch_regex='^STIM_')

    # smooth and subtract SR
    fn = lambda x: np.atleast_2d(sp.smooth(x.squeeze(), 3, 2) - SR)
    val['resp'] = val['resp'].transform(fn)
    val['resp'] = ohel.add_stimtype_epochs(val['resp'])

    if val['resp'].count_epoch('REFERENCE'):
        epochname = 'REFERENCE'
    else:
        epochname = 'TRIAL'
    sts = val['resp'].epochs['start'].copy()
    nds = val['resp'].epochs['end'].copy()
    sts_rec = rec['resp'].epochs['start'].copy()
    val['resp'].epochs['end'] = val['resp'].epochs['start'] + params['prestim']
    ps = val['resp'].select_epochs([epochname]).as_continuous()
    ff = np.isfinite(ps)
    SR_av = ps[ff].mean() * resp.fs
    SR_av_std = ps[ff].std() * resp.fs

    # Compute max over single-voice trials
    val['resp'].epochs['end'] = nds
    val['resp'].epochs['start'] = sts
    val['resp'].epochs[
        'start'] = val['resp'].epochs['start'] + params['prestim']
    TotalMax = np.nanmax(val['resp'].as_continuous())
    ps = np.hstack((val['resp'].extract_epoch('10').flatten(),
                    val['resp'].extract_epoch('01').flatten()))
    SinglesMax = np.nanmax(ps)

    # Compute threshold, excitatory percentage, and inhibitory percentage
    prestim, poststim = params['prestim'], params['poststim']
    val['resp'].epochs['end'] = nds
    val['resp'].epochs['start'] = sts
    val['resp'].epochs[
        'start'] = val['resp'].epochs['start'] + prestim + start_win_offset
    val['resp'].epochs['end'] = val['resp'].epochs['end'] - poststim
    thresh = np.array(((SR + SR_av_std) / resp.fs, (SR - SR_av_std) / resp.fs))
    thresh = np.array((SR / resp.fs + 0.1 * (SinglesMax - SR / resp.fs),
                       (SR - SR_av_std) / resp.fs))
    # SR/resp.fs - 0.5 * (np.nanmax(val['resp'].as_continuous()) - SR/resp.fs)]

    types = ['10', '01', '20', '02', '11', '12', '21', '22']
    excitatory_percentage = {}
    inhibitory_percentage = {}
    Max = {}
    Mean = {}
    for _type in types:
        if _type in val['resp'].epochs.name.values:
            ps = val['resp'].extract_epoch(_type).flatten()
            ff = np.isfinite(ps)
            excitatory_percentage[_type] = (ps[ff] >
                                            thresh[0]).sum() / ff.sum()
            inhibitory_percentage[_type] = (ps[ff] <
                                            thresh[1]).sum() / ff.sum()
            Max[_type] = ps[ff].max() / SinglesMax
            Mean[_type] = ps[ff].mean()

    # Compute threshold, excitatory percentage, and inhibitory percentage just over onset time
    # restore times
    val['resp'].epochs['end'] = nds
    val['resp'].epochs['start'] = sts
    # Change epochs to stimulus onset times
    val['resp'].epochs['start'] = val['resp'].epochs['start'] + prestim
    val['resp'].epochs['end'] = val['resp'].epochs['start'] + prestim + .5
    excitatory_percentage_onset = {}
    inhibitory_percentage_onset = {}
    Max_onset = {}
    for _type in types:
        ps = val['resp'].extract_epoch(_type).flatten()
        ff = np.isfinite(ps)
        excitatory_percentage_onset[_type] = (ps[ff] >
                                              thresh[0]).sum() / ff.sum()
        inhibitory_percentage_onset[_type] = (ps[ff] <
                                              thresh[1]).sum() / ff.sum()
        Max_onset[_type] = ps[ff].max() / SinglesMax

        # find correlations between double and single-voice responses
    val['resp'].epochs['end'] = nds
    val['resp'].epochs['start'] = sts
    val['resp'].epochs['start'] = val['resp'].epochs['start'] + prestim
    rec['resp'].epochs['start'] = rec['resp'].epochs['start'] + prestim
    # over stim on time to end + 0.5
    val['linmodel'] = val['resp'].copy()
    val['linmodel']._data = np.full(val['linmodel']._data.shape, np.nan)
    types = ['11', '12', '21', '22']
    epcs = val['resp'].epochs[val['resp'].epochs['name'].str.contains(
        'STIM')].copy()
    epcs['type'] = epcs['name'].apply(ohel.label_ep_type)
    names = [[n.split('_')[1], n.split('_')[2]] for n in epcs['name']]
    EA = np.array([n[0] for n in names])
    EB = np.array([n[1] for n in names])

    r_dual_B, r_dual_A, r_dual_B_nc, r_dual_A_nc = {}, {}, {}, {}
    r_dual_B_bal, r_dual_A_bal = {}, {}
    r_lin_B, r_lin_A, r_lin_B_nc, r_lin_A_nc = {}, {}, {}, {}
    r_lin_B_bal, r_lin_A_bal = {}, {}

    N_ac = 200
    full_resp = rec['resp'].rasterize()
    full_resp = full_resp.transform(fn)
    for _type in types:
        inds = np.nonzero(epcs['type'].values == _type)[0]
        rA_st, rB_st, r_st, rA_rB_st = [], [], [], []
        init = True
        for ind in inds:
            # for each dual-voice response
            r = val['resp'].extract_epoch(epcs.iloc[ind]['name'])
            if np.any(np.isfinite(r)):
                print(epcs.iloc[ind]['name'])
                # Find the indices of single-voice responses that match this dual-voice response
                indA = np.where((EA[ind] == EA) & (EB == 'null'))[0]
                indB = np.where((EB[ind] == EB) & (EA == 'null'))[0]
                if (len(indA) > 0) & (len(indB) > 0):
                    # from pdb import set_trace
                    # set_trace()
                    rA = val['resp'].extract_epoch(epcs.iloc[indA[0]]['name'])
                    rB = val['resp'].extract_epoch(epcs.iloc[indB[0]]['name'])
                    r_st.append(
                        full_resp.extract_epoch(epcs.iloc[ind]['name'])[:,
                                                                        0, :])
                    rA_st_ = full_resp.extract_epoch(
                        epcs.iloc[indA[0]]['name'])[:, 0, :]
                    rB_st_ = full_resp.extract_epoch(
                        epcs.iloc[indB[0]]['name'])[:, 0, :]
                    rA_st.append(rA_st_)
                    rB_st.append(rB_st_)
                    minreps = np.min((rA_st_.shape[0], rB_st_.shape[0]))
                    rA_rB_st.append(rA_st_[:minreps, :] + rB_st_[:minreps, :])
                    if init:
                        rA_ = rA.squeeze()
                        rB_ = rB.squeeze()
                        r_ = r.squeeze()
                        rA_rB_ = rA.squeeze() + rB.squeeze()
                        init = False
                    else:
                        rA_ = np.hstack((rA_, rA.squeeze()))
                        rB_ = np.hstack((rB_, rB.squeeze()))
                        r_ = np.hstack((r_, r.squeeze()))
                        rA_rB_ = np.hstack(
                            (rA_rB_, rA.squeeze() + rB.squeeze()))
                    val['linmodel'] = val['linmodel'].replace_epoch(
                        epcs.iloc[ind]['name'], rA + rB, preserve_nan=False)
        ff = np.isfinite(r_) & np.isfinite(rA_) & np.isfinite(
            rB_)  # find places with data
        r_dual_A[_type] = np.corrcoef(rA_[ff], r_[ff])[
            0, 1]  # Correlation between response to A and response to dual
        r_dual_B[_type] = np.corrcoef(rB_[ff], r_[ff])[
            0, 1]  # Correlation between response to B and response to dual
        r_lin_A[_type] = np.corrcoef(
            rA_[ff], rA_rB_[ff]
        )[0,
          1]  # Correlation between response to A and response to linear 'model'
        r_lin_B[_type] = np.corrcoef(
            rB_[ff], rA_rB_[ff]
        )[0,
          1]  # Correlation between response to B and response to linear 'model'

        # correlations over single-trial data
        minreps = np.min([x.shape[0] for x in r_st])
        r_st = [x[:minreps, :] for x in r_st]
        r_st = np.concatenate(r_st, axis=1)
        rA_st = [x[:minreps, :] for x in rA_st]
        rA_st = np.concatenate(rA_st, axis=1)
        rB_st = [x[:minreps, :] for x in rB_st]
        rB_st = np.concatenate(rB_st, axis=1)
        rA_rB_st = [x[:minreps, :] for x in rA_rB_st]
        rA_rB_st = np.concatenate(rA_rB_st, axis=1)

        r_lin_A_bal[_type] = np.corrcoef(rA_st[0::2, ff].mean(axis=0),
                                         rA_rB_st[1::2, ff].mean(axis=0))[0, 1]
        r_lin_B_bal[_type] = np.corrcoef(rB_st[0::2, ff].mean(axis=0),
                                         rA_rB_st[1::2, ff].mean(axis=0))[0, 1]
        r_dual_A_bal[_type] = np.corrcoef(rA_st[0::2, ff].mean(axis=0),
                                          r_st[:, ff].mean(axis=0))[0, 1]
        r_dual_B_bal[_type] = np.corrcoef(rB_st[0::2, ff].mean(axis=0),
                                          r_st[:, ff].mean(axis=0))[0, 1]

        r_dual_A_nc[_type] = ohel.r_noise_corrected(rA_st, r_st)
        r_dual_B_nc[_type] = ohel.r_noise_corrected(rB_st, r_st)
        r_lin_A_nc[_type] = ohel.r_noise_corrected(rA_st, rA_rB_st)
        r_lin_B_nc[_type] = ohel.r_noise_corrected(rB_st, rA_rB_st)

        if _type == '11':
            r11 = nems.metrics.corrcoef._r_single(r_st, 200, 0)
        elif _type == '12':
            r12 = nems.metrics.corrcoef._r_single(r_st, 200, 0)
        elif _type == '21':
            r21 = nems.metrics.corrcoef._r_single(r_st, 200, 0)
        elif _type == '22':
            r22 = nems.metrics.corrcoef._r_single(r_st, 200, 0)
        # rac = _r_single(X, N)
        # r_ceiling = [nmet.r_ceiling(p, rec, 'pred', 'resp') for p in val_copy]

    # Things that used to happen only when _type == 'C' but still seem valid
    r_A_B = np.corrcoef(rA_[ff], rB_[ff])[0, 1]
    r_A_B_nc = r_noise_corrected(rA_st, rB_st)
    rAA = nems.metrics.corrcoef._r_single(rA_st, 200, 0)
    rBB = nems.metrics.corrcoef._r_single(rB_st, 200, 0)
    Np = 0
    rAA_nc = np.zeros(Np)
    rBB_nc = np.zeros(Np)
    hv = int(minreps / 2)
    for i in range(Np):
        inds = np.random.permutation(minreps)
        rAA_nc[i] = sp.r_noise_corrected(rA_st[inds[:hv]], rA_st[inds[hv:]])
        rBB_nc[i] = sp.r_noise_corrected(rB_st[inds[:hv]], rB_st[inds[hv:]])
    ffA = np.isfinite(rAA_nc)
    ffB = np.isfinite(rBB_nc)
    rAAm = rAA_nc[ffA].mean()
    rBBm = rBB_nc[ffB].mean()
    mean_nsA = rA_st.sum(axis=1).mean()
    mean_nsB = rB_st.sum(axis=1).mean()
    min_nsA = rA_st.sum(axis=1).min()
    min_nsB = rB_st.sum(axis=1).min()

    # Calculate correlation between linear 'model' and dual-voice response, and mean amount of suppression / enhancement relative to linear 'model'
    r_fit_linmodel = {}
    r_fit_linmodel_NM = {}
    r_ceil_linmodel = {}
    mean_enh = {}
    mean_supp = {}
    EnhP = {}
    SuppP = {}
    DualAboveZeroP = {}
    resp_ = copy.deepcopy(rec['resp'].rasterize())
    resp_.epochs['start'] = sts_rec
    fn = lambda x: np.atleast_2d(
        sp.smooth(x.squeeze(), 3, 2) - SR / val['resp'].fs)
    resp_ = resp_.transform(fn)
    for _type in types:
        val_copy = copy.deepcopy(val)
        #        from pdb import set_trace
        #        set_trace()
        val_copy['resp'] = val_copy['resp'].select_epochs([_type])
        # Correlation between linear 'model' (response to A plus response to B) and dual-voice response
        r_fit_linmodel_NM[_type] = nmet.corrcoef(val_copy, 'linmodel', 'resp')
        # r_ceil_linmodel[_type] = nems.metrics.corrcoef.r_ceiling(val_copy,rec,'linmodel', 'resp',exclude_neg_pred=False)[0]
        # Noise-corrected correlation between linear 'model' (response to A plus response to B) and dual-voice response
        r_ceil_linmodel[_type] = nems.metrics.corrcoef.r_ceiling(
            val_copy, rec, 'linmodel', 'resp')[0]

        pred = val_copy['linmodel'].as_continuous()
        resp = val_copy['resp'].as_continuous()
        ff = np.isfinite(pred) & np.isfinite(resp)
        # cc = np.corrcoef(sp.smooth(pred[ff],3,2), sp.smooth(resp[ff],3,2))
        cc = np.corrcoef(pred[ff], resp[ff])
        r_fit_linmodel[_type] = cc[0, 1]

        prdiff = resp[ff] - pred[ff]
        mean_enh[_type] = prdiff[prdiff > 0].mean() * val['resp'].fs
        mean_supp[_type] = prdiff[prdiff < 0].mean() * val['resp'].fs

        # Find percent of time response is suppressed vs enhanced relative to what would be expected by a linear sum of single-voice responses
        # First, jacknife to find...
    #        Njk=10
    #        if _type is 'C':
    #            stims=['STIM_T+si464+si464','STIM_T+si516+si516']
    #        else:
    #            stims=['STIM_T+si464+si516', 'STIM_T+si516+si464']
    #        T=int(700+prestim*val['resp'].fs)
    #        Tps=int(prestim*val['resp'].fs)
    #        jns=np.zeros((Njk,T,len(stims)))
    #        for ns in range(len(stims)):
    #            for njk in range(Njk):
    #                resp_jn=resp_.jackknife_by_epoch(Njk,njk,stims[ns])
    #                jns[njk,:,ns]=np.nanmean(resp_jn.extract_epoch(stims[ns]),axis=0)
    #        jns=np.reshape(jns[:,Tps:,:],(Njk,700*len(stims)),order='F')
    #
    #        lim_models=np.zeros((700,len(stims)))
    #        for ns in range(len(stims)):
    #            lim_models[:,ns]=val_copy['linmodel'].extract_epoch(stims[ns])
    #        lim_models=lim_models.reshape(700*len(stims),order='F')
    #
    #        ff=np.isfinite(lim_models)
    #        mean_diff=(jns[:,ff]-lim_models[ff]).mean(axis=0)
    #        std_diff=(jns[:,ff]-lim_models[ff]).std(axis=0)
    #        serr_diff=np.sqrt(Njk/(Njk-1))*std_diff
    #
    #        thresh=3
    #        dual_above_zero = (jns[:,ff].mean(axis=0) > std_diff)
    #        sig_enh = ((mean_diff/serr_diff) > thresh) & dual_above_zero
    #        sig_supp = ((mean_diff/serr_diff) < -thresh)
    #        DualAboveZeroP[_type] = (dual_above_zero).sum()/len(mean_diff)
    #        EnhP[_type] = (sig_enh).sum()/len(mean_diff)
    #        SuppP[_type] = (sig_supp).sum()/len(mean_diff)

    #        time = np.arange(0, lim_models.shape[0])/ val['resp'].fs
    #        plt.figure();
    #        plt.plot(time,jns.mean(axis=0),'.-k');
    #        plt.plot(time,lim_models,'.-g');
    #        plt.plot(time[sig_enh],lim_models[sig_enh],'.r')
    #        plt.plot(time[sig_supp],lim_models[sig_supp],'.b')
    #        plt.title('Type:{:s}, Enh:{:.2f}, Sup:{:.2f}, Resp_above_zero:{:.2f}'.format(_type,EnhP[_type],SuppP[_type],DualAboveZeroP[_type]))
    #        from pdb import set_trace
    #        set_trace()
    #        a=2
    # thrsh=5
    #        EnhP[_type] = ((prdiff*val['resp'].fs) > thresh).sum()/len(prdiff)
    #        SuppP[_type] = ((prdiff*val['resp'].fs) < -thresh).sum()/len(prdiff)
    #    return val
    #    return {'excitatory_percentage':excitatory_percentage,
    #            'inhibitory_percentage':inhibitory_percentage,
    #            'r_fit_linmodel':r_fit_linmodel,
    #            'SR':SR, 'SR_std':SR_std, 'SR_av_std':SR_av_std}
    #
    return {
        'thresh': thresh * val['resp'].fs,
        'EP_A': excitatory_percentage['A'],
        'EP_B': excitatory_percentage['B'],
        #            'EP_C':excitatory_percentage['C'],
        'EP_I': excitatory_percentage['I'],
        'IP_A': inhibitory_percentage['A'],
        'IP_B': inhibitory_percentage['B'],
        #            'IP_C':inhibitory_percentage['C'],
        'IP_I': inhibitory_percentage['I'],
        'OEP_A': excitatory_percentage_onset['A'],
        'OEP_B': excitatory_percentage_onset['B'],
        #            'OEP_C':excitatory_percentage_onset['C'],
        'OEP_I': excitatory_percentage_onset['I'],
        'OIP_A': inhibitory_percentage_onset['A'],
        'OIP_B': inhibitory_percentage_onset['B'],
        #            'OIP_C':inhibitory_percentage_onset['C'],
        'OIP_I': inhibitory_percentage_onset['I'],
        'Max_A': Max['A'],
        'Max_B': Max['B'],
        #            'Max_C':Max['C'],
        'Max_I': Max['I'],
        'Mean_A': Mean['A'],
        'Mean_B': Mean['B'],
        #            'Mean_C':Mean['C'],
        'Mean_I': Mean['I'],
        'OMax_A': Max_onset['A'],
        'OMax_B': Max_onset['B'],
        #            'OMax_C':Max_onset['C'],
        'OMax_I': Max_onset['I'],
        'TotalMax': TotalMax * val['resp'].fs,
        'SinglesMax': SinglesMax * val['resp'].fs,
        #            'r_lin_C':r_fit_linmodel['C'],
        'r_lin_I': r_fit_linmodel['I'],
        #            'r_lin_C_NM':r_fit_linmodel_NM['C'],
        'r_lin_I_NM': r_fit_linmodel_NM['I'],
        #            'r_ceil_C':r_ceil_linmodel['C'],
        'r_ceil_I': r_ceil_linmodel['I'],
        #            'MEnh_C':mean_enh['C'],
        'MEnh_I': mean_enh['I'],
        #            'MSupp_C':mean_supp['C'],
        'MSupp_I': mean_supp['I'],
        #            'EnhP_C':EnhP['C'],
        #        'EnhP_I':EnhP['I'],
        #            'SuppP_C':SuppP['C'],
        #        'SuppP_I':SuppP['I'],
        #            'DualAboveZeroP_C':DualAboveZeroP['C'],
        #        'DualAboveZeroP_I':DualAboveZeroP['I'],
        #            'r_dual_A_C':r_dual_A['C'],
        'r_dual_A_I': r_dual_A['I'],
        #            'r_dual_B_C':r_dual_B['C'],
        'r_dual_B_I': r_dual_B['I'],
        #            'r_dual_A_C_nc':r_dual_A_nc['C'],
        'r_dual_A_I_nc': r_dual_A_nc['I'],
        #            'r_dual_B_C_nc':r_dual_B_nc['C'],
        'r_dual_B_I_nc': r_dual_B_nc['I'],
        #            'r_dual_A_C_bal':r_dual_A_bal['C'],
        'r_dual_A_I_bal': r_dual_A_bal['I'],
        #            'r_dual_B_C_bal':r_dual_B_bal['C'],
        'r_dual_B_I_bal': r_dual_B_bal['I'],
        #            'r_lin_A_C':r_lin_A['C'],
        'r_lin_A_I': r_lin_A['I'],
        #            'r_lin_B_C':r_lin_B['C'],
        'r_lin_B_I': r_lin_B['I'],
        #            'r_lin_A_C_nc':r_lin_A_nc['C'],
        'r_lin_A_I_nc': r_lin_A_nc['I'],
        #            'r_lin_B_C_nc':r_lin_B_nc['C'],
        'r_lin_B_I_nc': r_lin_B_nc['I'],
        #            'r_lin_A_C_bal':r_lin_A_bal['C'],
        'r_lin_A_I_bal': r_lin_A_bal['I'],
        #            'r_lin_B_C_bal':r_lin_B_bal['C'],
        'r_lin_B_I_bal': r_lin_B_bal['I'],
        'r_A_B': r_A_B,
        'r_A_B_nc': r_A_B_nc,
        'rAAm': rAAm,
        'rBBm': rBBm,
        'rAA': rAA,
        'rBB': rBB,
        'rII': rII,
        #           'rCC':rCC
        'rAA_nc': rAA_nc,
        'rBB_nc': rBB_nc,
        'mean_nsA': mean_nsA,
        'mean_nsB': mean_nsB,
        'min_nsA': min_nsA,
        'min_nsB': min_nsB,
        'SR': SR,
        'SR_std': SR_std,
        'SR_av_std': SR_av_std,
        'norm_spont': norm_spont,
        'spont_rate': spont_rate,
        'params': params,
        'corcoef': corcoef,
        'avg_resp': avg_resp,
        'snr': snr,
        'pair_names': twostims,
        'suppression': supp_array,
        'FR': FR_array,
        'rec': rec,
        'animal': cellid[:3]
    }
Example n. 25
# siteids
sites = ['CRD009b', 'CRD010b', 'CRD011c', 'CRD012b', 'CRD013b', 'CRD016c', 'CRD017c', 'CRD018d', 'CRD019b']
for site in sites:
    if not os.path.isdir(os.path.join(fig_path, site)):
        os.mkdir(os.path.join(fig_path, site))

    site_path = os.path.join(fig_path, site)
    # get parmfiles
    sql = "SELECT sCellFile.cellid, sCellFile.respfile, gDataRaw.resppath from sCellFile INNER JOIN" \
               " gCellMaster ON (gCellMaster.id=sCellFile.masterid) INNER JOIN" \
               " gDataRaw ON (sCellFile.rawid=gDataRaw.id)" \
               " WHERE gCellMaster.siteid=%s" \
               " and gDataRaw.runclass='TBP' and gDataRaw.bad=0"
    d = nd.pd_query(sql, (site,))
    d['parmfile'] = [f.replace('.spk.mat', '.m') for f in d['respfile']]
    parmfiles = np.unique(np.sort([os.path.join(d['resppath'].iloc[i], d['parmfile'].iloc[i]) for i in range(d.shape[0])])).tolist()
    manager = BAPHYExperiment(parmfiles)
    rec = manager.get_recording(**options)
    rec['resp'] = rec['resp'].rasterize()

    # find / sort epoch names
    files = [f for f in rec['resp'].epochs.name.unique() if 'FILE_' in f]
    targets = [f for f in rec['resp'].epochs.name.unique() if 'TAR_' in f]
    catch = [f for f in rec['resp'].epochs.name.unique() if 'CAT_' in f]

    sounds = targets + catch
    ref_stims = [x for x in rec['resp'].epochs.name.unique() if 'STIM_' in x]
    idx = np.argsort([int(s.split('_')[-1]) for s in ref_stims])
    ref_stims = np.array(ref_stims)[idx].tolist()
# ======================================================================

dfpath = '/auto/users/hellerc/code/projects/rewardLearning/R01_renewal_figs/results/'
runclass = 'BVT'
animal = 'Cordyceps'
earliest_date = '2019_11_09'

#animal = 'Drechsler'
#earliest_date = '2019_10_14'

an_regex = "%"+animal+"%"
min_trials = 50

# get list of all training parmfiles
sql = "SELECT parmfile, resppath, gDataRaw.id, cellid, gData.svalue, gData.value FROM gDataRaw INNER JOIN gData on (gDataRaw.id=gData.rawid and gData.name='Tar_Frequencies') WHERE runclass=%s and resppath like %s and training = 1 and bad=0 and trials>%s and behavior='active'"
parmfiles1 = nd.pd_query(sql, (runclass, an_regex, min_trials))
parmfiles1 = parmfiles1.set_index('id')

sql = "SELECT parmfile, resppath, gDataRaw.id, cellid, gData.svalue FROM gDataRaw INNER JOIN gData on (gDataRaw.id=gData.rawid and gData.name='Behave_PumpDuration') WHERE runclass=%s and resppath like %s and training = 1 and bad = 0 and trials>%s and behavior='active'"
parmfiles = nd.pd_query(sql, (runclass, an_regex, min_trials))
parmfiles = parmfiles.set_index('id')
parmfiles = parmfiles.rename(columns={'svalue': 'pumpdur'})

parmfiles = pd.concat([parmfiles['pumpdur'], parmfiles1], axis=1, join='outer')

# screen for dates
parmfiles['date'] = [dt.datetime.strptime('-'.join(x.split('_')[1:-2]), '%Y-%m-%d') for x in parmfiles.parmfile]
ed = dt.datetime.strptime(earliest_date, '%Y_%m_%d')
parmfiles = parmfiles[parmfiles.date > ed]
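
# Hedged follow-up (not in the original): count the qualifying training sessions per day
print(parmfiles.groupby('date').size())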

options = {'pupil': False, 'rasterfs': 100}
Example n. 27
def mark_complete(pupilfiles):
    """
    Save predictions for all files and update celldb to mark these analyses as complete.
    
    !!!!!!!!!!!!!!!! Take care using this!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    
    It's a good idea to first use pupil_browser to validate the results
    of the fit. However, if you collect many videos in a row and the hardware set up 
    doesn't change, it's likely you can just perform the pupil_browser QC on the first 
    video in the sequence and then batch process the rest of the videos using this function.
    """
    for vf in pupilfiles:
        video_name = os.path.splitext(os.path.split(vf)[-1])[0]

        fn = video_name + '.pickle'
        fn_mat = video_name + '.mat'
        fp = os.path.split(vf)[0]
        save_path = os.path.join(fp, 'sorted', fn)
        # for matlab loading
        mat_fn = os.path.join(fp, fn_mat)

        sorted_dir = os.path.split(save_path)[0]

        if not os.path.isdir(sorted_dir):
            # create sorted directory and force it to be world-writeable
            os.system("mkdir {}".format(sorted_dir))
            os.system("chmod a+w {}".format(sorted_dir))
            print("created new directory {0}".format(sorted_dir))

        try:
            # load predictions
            pred_path = os.path.join(sorted_dir, fn.replace('.pickle', '_pred.pickle'))
            with open(pred_path, 'rb') as fp:
                save_dict = pickle.load(fp)

            # No excluded frames options for batch saving
            save_dict['cnn']['excluded_frames'] = []

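            # eye speed: Euclidean frame-to-frame displacement of the fitted (x, y) position trace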
            x_diff = np.diff(save_dict['cnn']['x'])
            y_diff = np.diff(save_dict['cnn']['y'])
            d = np.sqrt((x_diff ** 2) + (y_diff ** 2))
            d[-1] = 0
            d = np.concatenate((d, np.zeros(1)))
            save_dict['cnn']['eyespeed'] = d
            with open(save_path, 'wb') as fp:
                pickle.dump(save_dict, fp, protocol=pickle.HIGHEST_PROTOCOL)

            scipy.io.savemat(mat_fn, save_dict)

            # finally, update celldb to mark pupil as analyzed
            get_file1 = "SELECT eyecalfile from gDataRaw where eyecalfile LIKE %s" 
            out1 = nd.pd_query(get_file1, params=("%" + vf.replace('/auto/data/daq/', '') + "%",))
            og_video_path = out1.iloc[0][0]
            sql = "UPDATE gDataRaw SET eyewin=2 WHERE eyecalfile='{}'".format(og_video_path)

            nd.sql_command(sql)

            log.info("Saved analysis successfully for {}".format(video_name))
        
        except Exception:
            log.info("No saved pupil analysis {}".format(os.path.join(sorted_dir, fn.replace('.pickle', '_pred.pickle'))))