def ddr_sum_all(batch=331, modelname_base=None):
    """Aggregate ddr_pred_site_sim results over every site in a batch.

    For each site in the batch, runs ``ddr_pred_site_sim`` and collects its
    correlation (``cc``) and mean-squared-error (``mse``) matrices into one
    row per site, with one column per (label, metric) combination.

    Parameters
    ----------
    batch : int
        Database batch id passed to ``db.get_batch_sites`` and
        ``ddr_pred_site_sim``.
    modelname_base : str or None
        Forwarded to ``ddr_pred_site_sim``.

    Returns
    -------
    df : pandas.DataFrame
        One row per successfully processed site.
    labels : list of list of str
        Six column-name groups, in the order
        [abs_cc, cc, raw_cc, abs_mse, mse, raw_mse], derived from the
        labels returned for the last successful site. Empty list if no
        site succeeded.
    """
    siteids, cellids = db.get_batch_sites(batch=batch)

    # Column-name suffixes, in the same order the metric columns are read
    # out of cc (columns 0..2) and mse (columns 0..2) below.
    suffixes = ['_abs_cc', '_cc', '_raw_cc', '_abs_mse', '_mse', '_raw_mse']

    res = []
    label_groups = None
    for site in siteids:
        try:
            labels, cc, mse, pupil_range = ddr_pred_site_sim(
                site, batch=batch, modelname_base=modelname_base,
                skip_plot=True)
        except Exception as e:
            # fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit and hid the failure reason.
            print(f"Skipping site {site}")
            log_msg = f"ddr_pred_site_sim failed for {site}: {e}"
            print(log_msg)
            plt.close()
            continue

        # fix: build the six parallel label lists from one suffix table
        # instead of six copy-pasted comprehensions.
        label_groups = [[l + suf for l in labels] for suf in suffixes]

        d = {
            'site': site,
            'batch': batch,
            'pupil_range': pupil_range,
            'cc_base': cc[0, 0],
        }
        for i in range(len(labels)):
            for j in range(3):
                d[label_groups[j][i]] = cc[i, j]       # abs/std/raw cc
                d[label_groups[3 + j][i]] = mse[i, j]  # abs/std/raw mse
        res.append(pd.DataFrame(d, index=[0]))
        plt.close()

    if not res:
        # fix: previously `pd.concat([])` raised ValueError and the label
        # lists raised NameError when every site failed.
        return pd.DataFrame(), []

    df = pd.concat(res, ignore_index=True)
    return df, label_groups
# Script section: queue population-model fits for one batch on exacloud.
from pathlib import Path
from scipy.io import wavfile
import matplotlib.pyplot as plt
from nems_lbhb.baphy_experiment import BAPHYExperiment
from nems.analysis.gammatone.gtgram import gtgram
import nems.epoch as ep
from nems import db
from nems_lbhb.exacloud.queue_exacloud_job import enqueue_exacloud_models

# Batch of sites to fit; siteids/cellids pulled from the NEMS database.
batch = 338
siteids, cellids = db.get_batch_sites(batch)

#enqueue_models(celllist, batch, modellist, force_rerun=False,
#               user="******", codeHash="master", jerbQuery='',
#               executable_path=None, script_path=None,
#               priority=1, GPU_job=0, reserve_gb=0)

# Interpreter and fit script used on the exacloud cluster.
# NOTE(review): hard-coded user-specific paths — update for other users.
executable_path = '/home/svd/bin/miniconda3/envs/tfg/bin/python'
script_path = '/auto/users/svd/python/nems/scripts/fit_single.py'
GPU_job = True

# Model architectures to queue: LN pop (wc-fir-wc) and 1D-CNN variants,
# each in mono ("ch18") and binaural ("ch18.bin") loader versions.
modelnames = [
    "gtgram.fs100.ch18-ld.pop-norm.l1-sev_wc.Nx60-fir.1x20x60-wc.60xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    "gtgram.fs100.ch18.bin-ld.pop-norm.l1-sev_wc.Nx60-fir.1x20x60-wc.60xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    "gtgram.fs100.ch18-ld.pop-norm.l1-sev_wc.Nx40-fir.1x20x40-relu.40.f-wc.40x60-relu.60.f-wc.60xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    "gtgram.fs100.ch18.bin-ld.pop-norm.l1-sev_wc.Nx40-fir.1x20x40-relu.40.f-wc.40x60-relu.60.f-wc.60xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    "gtgram.fs100.ch18-ld.pop-norm.l1-sev_wc.Nx40-fir.1x20x40-relu.40.f-wc.40x50-relu.50.f-wc.50xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4",
    "gtgram.fs100.ch18.bin-ld.pop-norm.l1-sev_wc.Nx40-fir.1x20x40-relu.40.f-wc.40x50-relu.50.f-wc.50xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4"
]
def initialize_with_prefit(modelspec, meta, area="A1", cellid=None, siteid=None,
                           batch=322, pre_batch=None, use_matched=False,
                           use_simulated=False, use_full_model=False,
                           prefit_type=None, freeze_early=True, IsReload=False,
                           **ctx):
    """
    replace early layers of model with fit parameters from a "standard" model
    ... for now that's model with the same architecture fit to the NAT4 dataset

    Builds a "model_search" string identifying the previously fit model,
    looks up its modelpath in the Results table, loads its saved modelspec
    files, and copies parameters into ``modelspec`` via ``load_phi``.

    Parameters (as used below; see hedged notes on branches):
        modelspec : NEMS modelspec to initialize.
        meta : dict with at least 'batch' and 'modelname' keys.
        area : cortical area; only used in the final fallback branch.
        cellid / siteid : target cell/site; cellid may be a list (first used).
        batch : overridden by int(meta['batch']) below.
        pre_batch : batch of the prefit model; defaults to the target batch.
        use_matched / use_simulated / use_full_model / prefit_type :
            select which prefit-lookup strategy is used.
        freeze_early : if True, freeze layers up to the last weight_channels.
        IsReload : if True, do nothing (xforms reload pass).

    Returns a new ctx dict (from load_phi), possibly with 'freeze_layers'
    and 'skip_init' set.

    for dnn single:
    initial model:
    modelname = "ozgf.fs100.ch18-ld-norm.l1-sev_wc.18x4.g-fir.1x25x4-relu.4.f-wc.4x1-lvl.1-dexp.1_tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.es20"
    use initial as pre-fit:
    modelname = "ozgf.fs100.ch18-ld-norm.l1-sev_wc.18x4.g-fir.1x25x4-relu.4.f-wc.4x1-lvl.1-dexp.1_prefit-tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4.es20"
    """
    if IsReload:
        # Nothing to do on an xforms reload pass.
        return {}

    # Locate weight_channels layers: copy parameters up to (and freeze up
    # to) the LAST weight_channels layer.
    xi = find_module("weight_channels", modelspec, find_all_matches=True)
    if len(xi) == 0:
        raise ValueError(f"modelspec has not weight_channels layer to align")
    copy_layers = xi[-1]
    freeze_layer_count = xi[-1]

    # batch param is superseded by the value recorded in meta.
    batch = int(meta['batch'])
    # modelname convention: "<loader>_<architecture>_<fitter>"
    modelname_parts = meta['modelname'].split("_")

    if use_simulated:
        # Load a modelspec previously saved for the simulated site SIM000a
        # with the same architecture (modelname_parts[1]).
        guess = '.'.join(['SIM000a', modelname_parts[1]])
        # remove problematic characters
        guess = re.sub('[:]', '', guess)
        guess = re.sub('[,]', '', guess)
        if len(guess) > 100:
            # If modelname is too long, causes filesystem errors.
            guess = guess[:75] + '...' + str(hashlib.sha1(guess.encode('utf-8')).hexdigest()[:20])
        old_uri = f"/auto/data/nems_db/modelspecs/{guess}/modelspec.0000.json"
        log.info('loading saved modelspec from: ' + old_uri)
        new_ctx = load_phi(modelspec, prefit_uri=old_uri, copy_layers=copy_layers)
        return new_ctx

    elif prefit_type == 'init':
        # use full pop file - SVD work in progress. current best?
        load_string_pop = "ozgf.fs100.ch18.pop-loadpop-norm.l1-popev"
        fit_string_pop = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4"
        pre_part = load_string_pop
        # Strip the leading (and, if present, trailing) fitter keyword from
        # the fitter string to form the search's post part.
        if len(modelname_parts[2].split("-")) > 2:
            post_part = "-".join(modelname_parts[2].split("-")[1:-1])
        else:
            post_part = "-".join(modelname_parts[2].split("-")[1:])
        model_search = "_".join([pre_part, modelname_parts[1], post_part])
        if pre_batch is None:
            pre_batch = batch
        # Reference cells whose pop fits are used as the prefit source.
        if pre_batch in [322, 334]:
            pre_cellid = 'ARM029a-07-6'
        elif pre_batch == 323:
            pre_cellid = 'ARM017a-01-9'
        else:
            raise ValueError(f"batch {pre_batch} prefit not implemented yet.")
        log.info(f"prefit cellid={pre_cellid}, skipping init_fit")
        # Copy ALL layers (not just up to weight_channels) for 'init'.
        copy_layers = len(modelspec)

    elif use_full_model:
        # use full pop file - SVD work in progress. current best?
        load_string_pop = "ozgf.fs100.ch18.pop-loadpop-norm.l1-popev"
        fit_string_pop = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4"
        # Choose the loader string of the prefit pop model based on
        # prefit_type (heldout/matched variants, with est-data fractions
        # encoded as .kNN) or, failing that, on the current modelname.
        if prefit_type == 'heldout':
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev"
        elif prefit_type == 'matched':
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev"
        elif prefit_type == 'matched_half':
            # 50% est data (matched cell excluded)
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev.k50"
        elif prefit_type == 'matched_quarter':
            # 25% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev.k25"
        elif prefit_type == 'matched_fifteen':
            # 15% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev.k15"
        elif prefit_type == 'matched_ten':
            # 10% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hm-norm.l1-popev.k10"
        elif prefit_type == 'heldout_half':
            # 50% est data, cell excluded (is this a useful condition?)
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev.k50"
        elif prefit_type == 'heldout_quarter':
            # 25% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev.k25"
        elif prefit_type == 'heldout_fifteen':
            # 15% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev.k15"
        elif prefit_type == 'heldout_ten':
            # 10% est data
            pre_part = "ozgf.fs100.ch18.pop-loadpop.hs-norm.l1-popev.k10"
        elif 'R.q.s' in modelname_parts[1]:
            pre_part = "ozgf.fs100.ch18-ld-norm.l1-sev"
        elif 'ch32' in modelname_parts[0]:
            pre_part = "ozgf.fs100.ch32.pop-loadpop-norm.l1-popev"
        elif 'ch64' in modelname_parts[0]:
            pre_part = "ozgf.fs100.ch64.pop-loadpop-norm.l1-popev"
        elif batch==333:
            # not pre-concatenated recording. different stim for each site,
            # so fit each site separately (unless titan)
            pre_part = "ozgf.fs100.ch18-ld-norm.l1-sev"
        else:
            #load_string_pop = "ozgf.fs100.ch18.pop-loadpop-norm.l1-popev"
            pre_part = load_string_pop

        # Fitter string of the prefit model; 'titan' fits used multi-core
        # (mc) fitter keywords.
        if prefit_type == 'titan':
            if batch==333:
                pre_part = load_string_pop
                post_part = "tfinit.n.mc50.lr1e3.et4.es20-newtf.n.mc100.lr1e4"
            else:
                post_part = "tfinit.n.mc25.lr1e3.es20-newtf.n.mc100.lr1e4.exa"
        else:
            #post_part = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.es20"
            post_part = fit_string_pop
        # Propagate the l2-regularization / version suffix of the current
        # fitter string onto the search's fitter string.
        if modelname_parts[2].endswith(".l2:5") or modelname_parts[2].endswith(".l2:5-dstrf") or modelname_parts[2].endswith("ver5"):
            post_part += ".l2:5"
        elif modelname_parts[2].endswith(".l2:4") or modelname_parts[2].endswith(".l2:4-dstrf") or modelname_parts[2].endswith("ver4"):
            post_part += ".l2:4"
        elif modelname_parts[2].endswith(".l2:4.ver2"):
            post_part += ".l2:4.ver2"
        elif modelname_parts[2].endswith("ver2"):
            post_part += ".ver2"
        elif modelname_parts[2].endswith("ver1"):
            post_part += ".ver1"
        model_search = "_".join([pre_part, modelname_parts[1], post_part])
        if pre_batch is None:
            pre_batch = batch

        # this is a single-cell fit
        if type(cellid) is list:
            cellid = cellid[0]
        siteid = cellid.split("-")[0]
        allsiteids, allcellids = nd.get_batch_sites(batch,
                                                    modelname_filter=model_search)
        allsiteids = [s.split(".")[0] for s in allsiteids]
        # Cross-area (A1<->PEG) prefits use a CSV mapping of SNR-matched
        # cells between batches 322 and 323.
        if (batch==323) and (pre_batch==322):
            matchfile=os.path.dirname(__file__) + "/projects/pop_model_scripts/snr_subset_map.csv"
            df = pd.read_csv(matchfile, index_col=0)
            pre_cellid = df.loc[df.PEG_cellid==cellid, 'A1_cellid'].values[0]
        elif (batch==322) and (pre_batch==323):
            matchfile=os.path.dirname(__file__) + "/projects/pop_model_scripts/snr_subset_map.csv"
            df = pd.read_csv(matchfile, index_col=0)
            pre_cellid = df.loc[df.A1_cellid==cellid, 'PEG_cellid'].values[0]
        elif siteid in allsiteids:
            # don't need to generalize, load from actual fit
            pre_cellid = cellid
        elif batch in [322, 334]:
            pre_cellid = 'ARM029a-07-6'
        elif pre_batch == 323:
            pre_cellid = 'ARM017a-01-9'
        else:
            raise ValueError(f"batch {batch} prefit not implemented yet.")
        log.info(f"prefit cellid={pre_cellid} prefit batch={pre_batch}")

    elif prefit_type == 'site':
        # exact same model, just fit for site, now being fit for single cell
        pre_parts = modelname_parts[0].split("-")
        post_parts = modelname_parts[2].split("-")
        # "%%" acts as the SQL LIKE wildcard in the Results query below.
        model_search = modelname_parts[0] + "%%" + modelname_parts[1] + "%%" + "-".join(post_parts[1:])
        pre_cellid = cellid[0]
        pre_batch = batch

    elif prefit_type is not None:
        # this is a single-cell fit
        if type(cellid) is list:
            cellid = cellid[0]
        if prefit_type=='heldout':
            # Use the performance-matched cell from another site as the
            # prefit source (cell itself was held out of that fit).
            if siteid is None:
                siteid=cellid.split("-")[0]
            cellids, this_perf, alt_cellid, alt_perf = _matching_cells(batch=batch, siteid=siteid)
            pre_cellid = [c_alt for c,c_alt in zip(cellids,alt_cellid) if c==cellid][0]
            log.info(f"heldout init for {cellid} is {pre_cellid}")
        else:
            pre_cellid = cellid
            log.info(f"matched cellid prefit for {cellid}")
        if pre_batch is None:
            pre_batch = batch
        post_part = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4"
        # Propagate l2/version suffix, as in the use_full_model branch.
        if modelname_parts[2].endswith(".l2:5") or modelname_parts[2].endswith(".l2:5-dstrf"):
            post_part += ".l2:5"
        elif modelname_parts[2].endswith(".l2:4") or modelname_parts[2].endswith(".l2:4-dstrf"):
            post_part += ".l2:4"
        elif modelname_parts[2].endswith(".l2:4.ver2"):
            post_part += ".l2:4.ver2"
        elif modelname_parts[2].endswith("ver2"):
            post_part += ".ver2"
        modelname_parts[2] = post_part
        model_search="_".join(modelname_parts)

    elif modelname_parts[1].endswith(".1"):
        # NOTE(review): everything after this raise is unreachable dead
        # code, preserved as-is.
        raise ValueError("deprecated prefit initialization?")
        # this is a single-cell fit
        if type(cellid) is list:
            cellid = cellid[0]
        if use_matched:
            # determine matched cell for this heldout cell
            if siteid is None:
                siteid=cellid.split("-")[0]
            cellids, this_perf, alt_cellid, alt_perf = _matching_cells(batch=batch, siteid=siteid)
            pre_cellid = [c_alt for c,c_alt in zip(cellids,alt_cellid) if c==cellid][0]
            log.info(f"matched cell for {cellid} is {pre_cellid}")
        else:
            pre_cellid = cellid[0]
            log.info(f"cellid prefit for {cellid}")
        if pre_batch is None:
            pre_batch = batch
        #postparts = modelname_parts[2].split("-")
        #postparts = [s for s in postparts if not(s.startswith("prefit"))]
        #modelname_parts[2]="-".join(postparts)
        modelname_parts[2] = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.es20"
        model_search="_".join(modelname_parts)

    else:
        # Default: search for a pop-model fit with the same architecture.
        pre_parts = modelname_parts[0].split("-")
        post_parts = modelname_parts[2].split("-")
        post_part = "tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.ver2"
        model_search = pre_parts[0] + ".pop%%" + modelname_parts[1] + "%%" + post_part
        #ozgf.fs100.ch18.pop-loadpop-norm.l1-popev
        #wc.18x70.g-fir.1x15x70-relu.70.f-wc.70x80-fir.1x10x80-relu.80.f-wc.80x100-relu.100-wc.100xR-lvl.R-dexp.R
        #tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4.ver2

        # hard-coded to use an A1 model!!!!
        if pre_batch == 322:
            pre_cellid = 'ARM029a-07-6'
        elif area == "A1":
            pre_cellid = 'ARM029a-07-6'
            pre_batch = 322
        else:
            raise ValueError(f"area {area} prefit not implemented")

    log.info(f"model_search: {model_search}")
    # Look up the prefit model's save path in the Results table.
    # NOTE(review): values are interpolated directly into the SQL string;
    # safe only because they come from internal code paths, not user input.
    sql = f"SELECT * FROM Results WHERE batch={pre_batch} and cellid='{pre_cellid}' and modelname like '{model_search}'"
    #log.info(sql)
    d = nd.pd_query(sql)

    #old_uri = adjust_uri_prefix(d['modelpath'][0] + '/modelspec.0000.json')
    old_uri = adjust_uri_prefix(d['modelpath'][0])
    log.info(f"Importing parameters from {old_uri}")
    # One modelspec JSON per estimation subset (cell_count of them).
    mspaths = [f"{old_uri}/modelspec.{i:04d}.json" for i in range(modelspec.cell_count)]
    print(mspaths)
    prefit_ctx = xforms.load_modelspecs([], uris=mspaths, IsReload=False)
    #_, prefit_ctx = xform_helper.load_model_xform(
    #    cellid=pre_cellid, batch=pre_batch,
    #    modelname=d['modelname'][0], eval_model=False)
    # Copy parameters of the first copy_layers layers into modelspec.
    new_ctx = load_phi(modelspec, prefit_modelspec=prefit_ctx['modelspec'],
                       copy_layers=copy_layers)
    if freeze_early:
        # Freeze the copied early layers during subsequent fitting.
        new_ctx['freeze_layers'] = list(np.arange(freeze_layer_count))
    if prefit_type == 'init':
        new_ctx['skip_init'] = True
    return new_ctx
def pop_file(stimfmt='ozgf', batch=None, cellid=None, rasterfs=50,
             chancount=18, siteid=None, loadkey=None, **options):
    """Resolve the recording URI(s) for a population ("pop") data file.

    Maps (siteid, batch) onto one or more pre-generated NAT subset
    recordings (NAT1/NAT3/NAT4/NAT4v2), plus — for siteid=='ALLCELLS'
    with batch 333 — per-site recordings generated on the fly.

    Returns a single URI string when exactly one subset matched, otherwise
    a list of URIs. NOTE(review): when both sitelist and a single subset
    are populated, only the last subset URI is returned and the per-site
    URIs are dropped — confirm this is intended.
    """
    # Only the site prefix (before any "-") is used for the lookup.
    siteid = siteid.split("-")[0]
    subsetstr = []
    sitelist = []
    if siteid == 'ALLCELLS':
        # Special siteid: load every relevant subset for the batch.
        if (batch in [322]):
            subsetstr = ["NAT4v2", "NAT3", "NAT1"]
        elif (batch in [323]):
            subsetstr = ["NAT4"]
        elif (batch in [333]):
            #runclass="OLP"
            #sql="SELECT sRunData.cellid,gData.svalue,gData.rawid FROM sRunData INNER JOIN" +\
            #    " sCellFile ON sRunData.cellid=sCellFile.cellid " +\
            #    " INNER JOIN gData ON" + \
            #    " sCellFile.rawid=gData.rawid AND gData.name='Ref_Combos'" +\
            #    " AND gData.svalue='Manual'" +\
            #    " INNER JOIN gRunClass on gRunClass.id=sCellFile.runclassid" +\
            #    f" WHERE sRunData.batch={batch} and gRunClass.name='{runclass}'"
            #d = nd.pd_query(sql)
            #d['siteid'] = d['cellid'].apply(nd.get_siteid)
            #sitelist = d['siteid'].unique()
            # Batch 333 has no pre-concatenated file: build a per-site
            # recording list instead (sites filtered by a reference model).
            modelname_filter = 'ozgf.fs100.ch18-ld-norm.l1-sev_wc.18x4.g-fir.4x25-lvl.1-dexp.1_tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4'
            sitelist, _ = nd.get_batch_sites(batch, modelname_filter=modelname_filter)
            #sitelist=sitelist[:4]
            #log.info('limiting sitelist to 4 entries!!!!!!!!!!!!!!!!!')
        else:
            raise ValueError(f'ALLCELLS not supported for batch {batch}')
    elif ((batch == 272) and (siteid == 'none')) or (siteid in [
        'bbl086b', 'TAR009d', 'TAR010c', 'TAR017b'
    ]):
        subsetstr = ["NAT1"]
    elif siteid in [
        'none', 'NAT3', 'AMT003c', 'AMT005c', 'AMT018a', 'AMT020a',
        'AMT023d', 'bbl099g', 'bbl104h', 'BRT026c', 'BRT032e', 'BRT033b',
        'BRT034f', 'BRT037b', 'BRT038b', 'BRT039c', 'AMT031a', 'AMT032a'
    ]:
        # Should use NAT3 as siteid going forward for better readability,
        # but left other options here for backwards compatibility.
        subsetstr = ["NAT3"]
    elif (batch in [322, 323, 333]) or (siteid == 'NAT4'):
        subsetstr = ["NAT4v2"]
    else:
        raise ValueError('site not known for popfile')

    use_API = get_setting('USE_NEMS_BAPHY_API')
    uri_root = '/auto/data/nems_db/recordings/'
    recording_uri_list = []

    # Cap on per-site recordings loaded for the multi-file (batch 333) case.
    #max_sites = 2;
    max_sites = 12
    log.info(f"TRUNCATING MULTI-FILE DATA AT {max_sites} RECORDINGS")
    for s in sitelist[:max_sites]:
        recording_uri = generate_recording_uri(batch=batch, cellid=s,
                                               stimfmt=stimfmt,
                                               rasterfs=rasterfs,
                                               chancount=chancount, **options)
        log.info(f'loading {recording_uri}')
        #if use_API:
        #    host = 'http://'+get_setting('NEMS_BAPHY_API_HOST')+":"+str(get_setting('NEMS_BAPHY_API_PORT'))
        #    recording_uri = host + '/recordings/' + str(batch) + '/' + recname + '.tgz'
        #else:
        #    recording_uri = '{}{}/{}.tgz'.format(uri_root, batch, recname)
        recording_uri_list.append(recording_uri)

    # Pre-generated subset recordings, served over the API or from disk.
    for s in subsetstr:
        recname = f"{s}_{stimfmt}.fs{rasterfs}.ch{chancount}"
        log.info(f'loading {recname}')
        #data_file = '{}{}/{}.tgz'.format(uri_root, batch, recname)
        if use_API:
            host = 'http://' + get_setting('NEMS_BAPHY_API_HOST') + ":" + str(
                get_setting('NEMS_BAPHY_API_PORT'))
            recording_uri = host + '/recordings/' + str(
                batch) + '/' + recname + '.tgz'
        else:
            recording_uri = '{}{}/{}.tgz'.format(uri_root, batch, recname)
        recording_uri_list.append(recording_uri)

    # Single subset -> single URI string; otherwise the full list.
    if len(subsetstr) == 1:
        return recording_uri
    else:
        return recording_uri_list
shortnames = ['2D-CNN', '1D-CNN', '1Dx2-CNN', 'pop-LN', 'single-CNN'] #shortnames=['conv2d', 'conv1d', 'conv1dx2', 'ln-pop', 'dnn-sing'] #shortnames=['conv1d','conv1dx2','dnn-sing'] HELDOUT_CROSSBATCH = cross_batch_modelname = ( "ozgf.fs100.ch18-ld-norm.l1-sev_" "wc.18x70.g-fir.1x15x70-relu.70.f-wc.70x80-fir.1x10x80-relu.80.f-" "wc.80x100-relu.100-wc.100xR-lvl.R-dexp.R_" "prefit.hs.b322-tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4") #### # A1 expanded dataset (v>=2) #### if VERSION > 1: NAT4_A1_SITES, rep_cellids = nd.get_batch_sites(322, POP_MODELS[1]) else: NAT4_A1_SITES = [ 'ARM029a', 'ARM030a', 'ARM031a', 'ARM032a', 'ARM033a', 'CRD016d', 'CRD017c', 'DRX006b.e1:64', 'DRX006b.e65:128', 'DRX007a.e1:64', 'DRX007a.e65:128', 'DRX008b.e1:64', 'DRX008b.e65:128',