Example #1
import matplotlib.pyplot as plt
import pandas as pd

from nems import db
# ddr_pred_site_sim is assumed to be defined/importable in this module.


def ddr_sum_all(batch=331, modelname_base=None):
    siteids, cellids = db.get_batch_sites(batch=batch)

    res = []
    for site in siteids:
        try:
            labels, cc, mse, pupil_range = ddr_pred_site_sim(
                site,
                batch=batch,
                modelname_base=modelname_base,
                skip_plot=True)
            labelsabs = [l + '_abs_cc' for l in labels]
            labelscc = [l + '_cc' for l in labels]
            labelsraw = [l + '_raw_cc' for l in labels]
            labelsmabs = [l + '_abs_mse' for l in labels]
            labelsmse = [l + '_mse' for l in labels]
            labelsmraw = [l + '_raw_mse' for l in labels]

            d = {
                'site': site,
                'batch': batch,
                'pupil_range': pupil_range,
                'cc_base': cc[0, 0]
            }
            for i in range(len(labelscc)):
                d[labelsabs[i]] = cc[i, 0]
                d[labelscc[i]] = cc[i, 1]
                d[labelsraw[i]] = cc[i, 2]
                d[labelsmabs[i]] = mse[i, 0]
                d[labelsmse[i]] = mse[i, 1]
                d[labelsmraw[i]] = mse[i, 2]

            res.append(pd.DataFrame(d, index=[0]))
        except Exception as e:
            print(f"Skipping site {site}: {e}")
            plt.close()

    df = pd.concat(res, ignore_index=True)

    # label lists carried over from the last successfully processed site
    # (they are identical for every site)
    labels = [
        labelsabs, labelscc, labelsraw, labelsmabs, labelsmse, labelsmraw
    ]

    return df, labels
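A minimal usage sketch (assumes the imports above resolve in this environment; the printed columns are illustrative only):

df, labels = ddr_sum_all(batch=331)
print(df[['site', 'cc_base', 'pupil_range']].head())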
Example #2
from pathlib import Path
from scipy.io import wavfile
import matplotlib.pyplot as plt

from nems_lbhb.baphy_experiment import BAPHYExperiment

from nems.analysis.gammatone.gtgram import gtgram
import nems.epoch as ep
from nems import db
from nems_lbhb.exacloud.queue_exacloud_job import enqueue_exacloud_models

batch = 338
siteids, cellids = db.get_batch_sites(batch)

#enqueue_models(celllist, batch, modellist, force_rerun=False,
#                   user="******", codeHash="master", jerbQuery='',
#                   executable_path=None, script_path=None,
#                   priority=1, GPU_job=0, reserve_gb=0)

# exacloud fit job settings
executable_path = '/home/svd/bin/miniconda3/envs/tfg/bin/python'
script_path = '/auto/users/svd/python/nems/scripts/fit_single.py'
GPU_job = True

# each architecture appears in two variants: with and without the .bin load option
modelnames = [
    "gtgram.fs100.ch18-ld.pop-norm.l1-sev_wc.Nx60-fir.1x20x60-wc.60xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    "gtgram.fs100.ch18.bin-ld.pop-norm.l1-sev_wc.Nx60-fir.1x20x60-wc.60xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    "gtgram.fs100.ch18-ld.pop-norm.l1-sev_wc.Nx40-fir.1x20x40-relu.40.f-wc.40x60-relu.60.f-wc.60xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    "gtgram.fs100.ch18.bin-ld.pop-norm.l1-sev_wc.Nx40-fir.1x20x40-relu.40.f-wc.40x60-relu.60.f-wc.60xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    "gtgram.fs100.ch18-ld.pop-norm.l1-sev_wc.Nx40-fir.1x20x40-relu.40.f-wc.40x50-relu.50.f-wc.50xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    "gtgram.fs100.ch18.bin-ld.pop-norm.l1-sev_wc.Nx40-fir.1x20x40-relu.40.f-wc.40x50-relu.50.f-wc.50xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4"
]
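The queueing call itself is not part of this excerpt; the sketch below shows how the variables above would plausibly be passed. The keyword names of enqueue_exacloud_models are an assumption, mirrored from the commented enqueue_models signature above:

# Assumed call pattern, not verified against enqueue_exacloud_models' signature.
enqueue_exacloud_models(
    celllist=siteids, batch=batch, modellist=modelnames,
    user="******", executable_path=executable_path,
    script_path=script_path, GPU_job=GPU_job)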
Example #3
def initialize_with_prefit(modelspec, meta, area="A1", cellid=None, siteid=None, batch=322, pre_batch=None,
                           use_matched=False, use_simulated=False, use_full_model=False, 
                           prefit_type=None, freeze_early=True, IsReload=False, **ctx):
    """
    replace early layers of model with fit parameters from a "standard" model ... for now that's model with the same architecture fit
    to the NAT4 dataset
    
    for dnn single:
    initial model:
    modelname = "ozgf.fs100.ch18-ld-norm.l1-sev_wc.18x4.g-fir.1x25x4-relu.4.f-wc.4x1-lvl.1-dexp.1_tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.es20"
    
    use initial as pre-fit:
    modelname = "ozgf.fs100.ch18-ld-norm.l1-sev_wc.18x4.g-fir.1x25x4-relu.4.f-wc.4x1-lvl.1-dexp.1_prefit-tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4.es20"

    """
    if IsReload:
        return {}

    xi = find_module("weight_channels", modelspec, find_all_matches=True)
    if len(xi) == 0:
        raise ValueError("modelspec has no weight_channels layer to align")

    copy_layers = xi[-1]
    freeze_layer_count = xi[-1]
    batch = int(meta['batch'])
    modelname_parts = meta['modelname'].split("_")
    
    if use_simulated:
        guess = '.'.join(['SIM000a', modelname_parts[1]])

        # remove problematic characters
        guess = re.sub('[:]', '', guess)
        guess = re.sub('[,]', '', guess)
        if len(guess) > 100:
            # If modelname is too long, causes filesystem errors.
            guess = guess[:75] + '...' + str(hashlib.sha1(guess.encode('utf-8')).hexdigest()[:20])

        old_uri = f"/auto/data/nems_db/modelspecs/{guess}/modelspec.0000.json"
        log.info('loading saved modelspec from: ' + old_uri)

        new_ctx = load_phi(modelspec, prefit_uri=old_uri, copy_layers=copy_layers)
        
        return new_ctx

    elif prefit_type == 'init':
        # use full pop file - SVD work in progress. current best?
        load_string_pop = "ozgf.fs100.ch18.pop-loadpop-norm.l1-popev"
        fit_string_pop = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4"

        pre_part = load_string_pop
        if len(modelname_parts[2].split("-")) > 2:
            post_part = "-".join(modelname_parts[2].split("-")[1:-1])
        else:
            post_part = "-".join(modelname_parts[2].split("-")[1:])

        model_search = "_".join([pre_part, modelname_parts[1], post_part])

        if pre_batch is None:
            pre_batch = batch
        if pre_batch in [322, 334]:
            pre_cellid = 'ARM029a-07-6'
        elif pre_batch == 323:
            pre_cellid = 'ARM017a-01-9'
        else:
            raise ValueError(f"batch {pre_batch} prefit not implemented yet.")

        log.info(f"prefit cellid={pre_cellid}, skipping init_fit")
        copy_layers = len(modelspec)

    elif use_full_model:
        
        # use full pop file - SVD work in progress. current best?
        load_string_pop = "ozgf.fs100.ch18.pop-loadpop-norm.l1-popev"
        fit_string_pop = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4"

        if prefit_type == 'heldout':
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev"
        elif prefit_type == 'matched':
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev"
        elif prefit_type == 'matched_half':
            # 50% est data (matched cell excluded)
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev.k50"
        elif prefit_type == 'matched_quarter':
            # 25% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev.k25"
        elif prefit_type == 'matched_fifteen':
            # 15% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev.k15"
        elif prefit_type == 'matched_ten':
            # 10% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev.k10"
        elif prefit_type == 'heldout_half':
            # 50% est data, cell excluded (is this a useful condition?)
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev.k50"
        elif prefit_type == 'heldout_quarter':
            # 25% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev.k25"
        elif prefit_type == 'heldout_fifteen':
            # 15% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev.k15"
        elif prefit_type == 'heldout_ten':
            # 10% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev.k10"
        elif 'R.q.s' in modelname_parts[1]:
            pre_part = "ozgf.fs100.ch18-ld-norm.l1-sev"
        elif 'ch32' in modelname_parts[0]:
            pre_part = "ozgf.fs100.ch32.pop-loadpop-norm.l1-popev"
        elif 'ch64' in modelname_parts[0]:
            pre_part = "ozgf.fs100.ch64.pop-loadpop-norm.l1-popev"
        elif batch==333:
            # not pre-concatenated recording. different stim for each site, 
            # so fit each site separately (unless titan)
            pre_part = "ozgf.fs100.ch18-ld-norm.l1-sev"
        else:
            #load_string_pop = "ozgf.fs100.ch18.pop-loadpop-norm.l1-popev"
            pre_part = load_string_pop

        if prefit_type == 'titan':
            if batch==333:
                pre_part = load_string_pop
                post_part = "tfinit.n.mc50.lr1e3.et4.es20-newtf.n.mc100.lr1e4"
            else:
                post_part = "tfinit.n.mc25.lr1e3.es20-newtf.n.mc100.lr1e4.exa"
        else:
            #post_part = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.es20"
            post_part = fit_string_pop

        if modelname_parts[2].endswith(".l2:5") or modelname_parts[2].endswith(".l2:5-dstrf") or modelname_parts[2].endswith("ver5"):
            post_part += ".l2:5"
        elif modelname_parts[2].endswith(".l2:4") or modelname_parts[2].endswith(".l2:4-dstrf") or modelname_parts[2].endswith("ver4"):
            post_part += ".l2:4"
        elif modelname_parts[2].endswith(".l2:4.ver2"):
            post_part += ".l2:4.ver2"
        elif modelname_parts[2].endswith("ver2"):
            post_part += ".ver2"
        elif modelname_parts[2].endswith("ver1"):
            post_part += ".ver1"

        model_search = "_".join([pre_part, modelname_parts[1], post_part])
        if pre_batch is None:
            pre_batch = batch

        # this is a single-cell fit
        if isinstance(cellid, list):
            cellid = cellid[0]
        siteid = cellid.split("-")[0]
        allsiteids, allcellids = nd.get_batch_sites(batch, modelname_filter=model_search)
        allsiteids = [s.split(".")[0] for s in allsiteids]

        if (batch==323) and (pre_batch==322):
            matchfile=os.path.dirname(__file__) + "/projects/pop_model_scripts/snr_subset_map.csv"
            df = pd.read_csv(matchfile, index_col=0)
            pre_cellid = df.loc[df.PEG_cellid==cellid, 'A1_cellid'].values[0]
        elif (batch==322) and (pre_batch==323):
            matchfile=os.path.dirname(__file__) + "/projects/pop_model_scripts/snr_subset_map.csv"
            df = pd.read_csv(matchfile, index_col=0)
            pre_cellid = df.loc[df.A1_cellid==cellid, 'PEG_cellid'].values[0]

        elif siteid in allsiteids:
            # don't need to generalize, load from actual fit
            pre_cellid = cellid
        elif batch in [322, 334]:
            pre_cellid = 'ARM029a-07-6'
        elif pre_batch == 323:
            pre_cellid = 'ARM017a-01-9'
        else:
            raise ValueError(f"batch {batch} prefit not implemented yet.")
            
        log.info(f"prefit cellid={pre_cellid} prefit batch={pre_batch}")

    elif prefit_type == 'site':
        # exact same model, just fit for site, now being fit for single cell
        pre_parts = modelname_parts[0].split("-")
        post_parts = modelname_parts[2].split("-")
        model_search = modelname_parts[0] + "%%" + modelname_parts[1] + "%%" + "-".join(post_parts[1:])

        pre_cellid = cellid[0]
        pre_batch = batch
    elif prefit_type is not None:
        # this is a single-cell fit
        if isinstance(cellid, list):
            cellid = cellid[0]
            
        if prefit_type=='heldout':
            if siteid is None:
                siteid=cellid.split("-")[0]
            cellids, this_perf, alt_cellid, alt_perf = _matching_cells(batch=batch, siteid=siteid)

            pre_cellid = [c_alt for c,c_alt in zip(cellids,alt_cellid) if c==cellid][0]
            log.info(f"heldout init for {cellid} is {pre_cellid}")
        else:
            pre_cellid = cellid
            log.info(f"matched cellid prefit for {cellid}")
        if pre_batch is None:
            pre_batch = batch

        post_part = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4"
        if modelname_parts[2].endswith(".l2:5") or modelname_parts[2].endswith(".l2:5-dstrf"):
            post_part += ".l2:5"
        elif modelname_parts[2].endswith(".l2:4") or modelname_parts[2].endswith(".l2:4-dstrf"):
            post_part += ".l2:4"
        elif modelname_parts[2].endswith(".l2:4.ver2"):
            post_part += ".l2:4.ver2"
        elif modelname_parts[2].endswith("ver2"):
            post_part += ".ver2"
        modelname_parts[2] = post_part
        model_search="_".join(modelname_parts)

    elif modelname_parts[1].endswith(".1"):
        # NOTE: the rest of this branch is unreachable because of this raise
        raise ValueError("deprecated prefit initialization?")
        # this is a single-cell fit
        if isinstance(cellid, list):
            cellid = cellid[0]
        
        if use_matched:
            # determine matched cell for this heldout cell
            if siteid is None:
                siteid=cellid.split("-")[0]
            cellids, this_perf, alt_cellid, alt_perf = _matching_cells(batch=batch, siteid=siteid)

            pre_cellid = [c_alt for c,c_alt in zip(cellids,alt_cellid) if c==cellid][0]
            log.info(f"matched cell for {cellid} is {pre_cellid}")
        else:
            pre_cellid = cellid[0]
            log.info(f"cellid prefit for {cellid}")
        if pre_batch is None:
            pre_batch = batch
        #postparts = modelname_parts[2].split("-")
        #postparts = [s for s in postparts if not(s.startswith("prefit"))]
        #modelname_parts[2]="-".join(postparts)
        modelname_parts[2] = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.es20"
        model_search="_".join(modelname_parts)

    else:
        pre_parts = modelname_parts[0].split("-")
        post_parts = modelname_parts[2].split("-")    
        post_part = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.ver2"
        model_search = pre_parts[0] + ".pop%%" + modelname_parts[1] + "%%" + post_part

        #ozgf.fs100.ch18.pop-loadpop-norm.l1-popev
        #wc.18x70.g-fir.1x15x70-relu.70.f-wc.70x80-fir.1x10x80-relu.80.f-wc.80x100-relu.100-wc.100xR-lvl.R-dexp.R
        #tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.ver2


        # hard-coded to use an A1 model!!!!
        if pre_batch == 322:
            pre_cellid = 'ARM029a-07-6'
        elif area == "A1":
            pre_cellid = 'ARM029a-07-6'
            pre_batch = 322
        else:
            raise ValueError(f"area {area} prefit not implemented")

    log.info(f"model_search: {model_search}")

    sql = f"SELECT * FROM Results WHERE batch={pre_batch} and cellid='{pre_cellid}' and modelname like '{model_search}'"
    #log.info(sql)
    
    d = nd.pd_query(sql)
    #old_uri = adjust_uri_prefix(d['modelpath'][0] + '/modelspec.0000.json')
    old_uri = adjust_uri_prefix(d['modelpath'][0])
    log.info(f"Importing parameters from {old_uri}")

    mspaths = [f"{old_uri}/modelspec.{i:04d}.json" for i in range(modelspec.cell_count)]
    print(mspaths)
    prefit_ctx = xforms.load_modelspecs([], uris=mspaths, IsReload=False)

    #_, prefit_ctx = xform_helper.load_model_xform(
    #    cellid=pre_cellid, batch=pre_batch,
    #    modelname=d['modelname'][0], eval_model=False)
    new_ctx = load_phi(modelspec, prefit_modelspec=prefit_ctx['modelspec'], copy_layers=copy_layers)
    if freeze_early:
        new_ctx['freeze_layers'] = list(np.arange(freeze_layer_count))
    if prefit_type == 'init':
        new_ctx['skip_init'] = True
    return new_ctx
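A rough sketch of a direct call. In practice this function runs as an xforms step triggered by a "prefit" keyword in the modelname; ctx stands in for the xforms context, and the argument values here are assumptions:

# Illustrative only: modelspec/meta come from earlier xforms steps.
new_ctx = initialize_with_prefit(ctx['modelspec'], ctx['meta'],
                                 cellid='ARM029a-07-6', batch=322,
                                 prefit_type='matched', freeze_early=True)
ctx.update(new_ctx)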
Example #4
def pop_file(stimfmt='ozgf',
             batch=None,
             cellid=None,
             rasterfs=50,
             chancount=18,
             siteid=None,
             loadkey=None,
             **options):

    # fall back to cellid when no siteid is given (assumption; the original
    # raised AttributeError on siteid=None)
    if siteid is None:
        siteid = cellid
    siteid = siteid.split("-")[0]
    subsetstr = []
    sitelist = []
    if siteid == 'ALLCELLS':
        if (batch in [322]):
            subsetstr = ["NAT4v2", "NAT3", "NAT1"]
        elif (batch in [323]):
            subsetstr = ["NAT4"]
        elif (batch in [333]):
            #runclass="OLP"
            #sql="SELECT sRunData.cellid,gData.svalue,gData.rawid FROM sRunData INNER JOIN" +\
            #        " sCellFile ON sRunData.cellid=sCellFile.cellid " +\
            #        " INNER JOIN gData ON" + \
            #        " sCellFile.rawid=gData.rawid AND gData.name='Ref_Combos'" +\
            #        " AND gData.svalue='Manual'" +\
            #        " INNER JOIN gRunClass on gRunClass.id=sCellFile.runclassid" +\
            #        f" WHERE sRunData.batch={batch} and gRunClass.name='{runclass}'"
            #d = nd.pd_query(sql)

            #d['siteid'] = d['cellid'].apply(nd.get_siteid)
            #sitelist = d['siteid'].unique()
            modelname_filter = 'ozgf.fs100.ch18-ld-norm.l1-sev_wc.18x4.g-fir.4x25-lvl.1-dexp.1_tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4'
            sitelist, _ = nd.get_batch_sites(batch,
                                             modelname_filter=modelname_filter)

            #sitelist=sitelist[:4]
            #log.info('limiting sitelist to 4 entries!!!!!!!!!!!!!!!!!')
        else:
            raise ValueError(f'ALLCELLS not supported for batch {batch}')
    elif ((batch == 272) and (siteid == 'none')) or (siteid in [
            'bbl086b', 'TAR009d', 'TAR010c', 'TAR017b'
    ]):
        subsetstr = ["NAT1"]
    elif siteid in [
            'none', 'NAT3', 'AMT003c', 'AMT005c', 'AMT018a', 'AMT020a',
            'AMT023d', 'bbl099g', 'bbl104h', 'BRT026c', 'BRT032e', 'BRT033b',
            'BRT034f', 'BRT037b', 'BRT038b', 'BRT039c', 'AMT031a', 'AMT032a'
    ]:
        # Should use NAT3 as siteid going forward for better readability,
        # but left other options here for backwards compatibility.
        subsetstr = ["NAT3"]
    elif (batch in [322, 323, 333]) or (siteid == 'NAT4'):
        subsetstr = ["NAT4v2"]
    else:
        raise ValueError('site not known for popfile')
    use_API = get_setting('USE_NEMS_BAPHY_API')

    uri_root = '/auto/data/nems_db/recordings/'

    recording_uri_list = []
    #max_sites = 2;
    max_sites = 12
    log.info(f"TRUNCATING MULTI-FILE DATA AT {max_sites} RECORDINGS")
    for s in sitelist[:max_sites]:
        recording_uri = generate_recording_uri(batch=batch,
                                               cellid=s,
                                               stimfmt=stimfmt,
                                               rasterfs=rasterfs,
                                               chancount=chancount,
                                               **options)
        log.info(f'loading {recording_uri}')
        #if use_API:
        #    host = 'http://'+get_setting('NEMS_BAPHY_API_HOST')+":"+str(get_setting('NEMS_BAPHY_API_PORT'))
        #    recording_uri = host + '/recordings/' + str(batch) + '/' + recname + '.tgz'
        #else:
        #    recording_uri = '{}{}/{}.tgz'.format(uri_root, batch, recname)
        recording_uri_list.append(recording_uri)
    for s in subsetstr:
        recname = f"{s}_{stimfmt}.fs{rasterfs}.ch{chancount}"
        log.info(f'loading {recname}')
        #data_file = '{}{}/{}.tgz'.format(uri_root, batch, recname)

        if use_API:
            host = 'http://' + get_setting('NEMS_BAPHY_API_HOST') + ":" + str(
                get_setting('NEMS_BAPHY_API_PORT'))
            recording_uri = host + '/recordings/' + str(
                batch) + '/' + recname + '.tgz'
        else:
            recording_uri = '{}{}/{}.tgz'.format(uri_root, batch, recname)
        recording_uri_list.append(recording_uri)
    if len(subsetstr) == 1:
        # single pooled recording: return its URI as a bare string
        return recording_uri
    else:
        return recording_uri_list
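Hedged usage example (values taken from the branches above; with a single pooled recording the function returns one URI string rather than a list):

uri = pop_file(stimfmt='ozgf', batch=322, siteid='NAT4',
               rasterfs=100, chancount=18)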
Example #5
shortnames = ['2D-CNN', '1D-CNN', '1Dx2-CNN', 'pop-LN', 'single-CNN']
#shortnames=['conv2d', 'conv1d', 'conv1dx2', 'ln-pop', 'dnn-sing']
#shortnames=['conv1d','conv1dx2','dnn-sing']

HELDOUT_CROSSBATCH = cross_batch_modelname = (
    "ozgf.fs100.ch18-ld-norm.l1-sev_"
    "wc.18x70.g-fir.1x15x70-relu.70.f-wc.70x80-fir.1x10x80-relu.80.f-"
    "wc.80x100-relu.100-wc.100xR-lvl.R-dexp.R_"
    "prefit.hs.b322-tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4")

####
# A1 expanded dataset (v>=2)
####
if VERSION > 1:
    NAT4_A1_SITES, rep_cellids = nd.get_batch_sites(322, POP_MODELS[1])
else:
    NAT4_A1_SITES = [
        'ARM029a',
        'ARM030a',
        'ARM031a',
        'ARM032a',
        'ARM033a',
        'CRD016d',
        'CRD017c',
        'DRX006b.e1:64',
        'DRX006b.e65:128',
        'DRX007a.e1:64',
        'DRX007a.e65:128',
        'DRX008b.e1:64',
        'DRX008b.e65:128',