def cv_train_from_mat(lbl_file, cdir, cv_info_file, models_run, view=0,
                      skip_db=False, create_splits=True, dorun=False,
                      run_type='status'):
    """Create per-split train/val files and DBs from a cv-info file, optionally train.

    For every split index found in the cv-info file, the rows assigned to that
    split form the validation set and all remaining labeled rows form the
    training set. The pair is written as JSON to
    ``<cdir>/<projname>/<cvfile>__split<k>.json`` and a DB is built for every
    network type in ``models_run``.

    Args:
        lbl_file: APT label file; its 'projname' entry is read with h5py.
        cdir: cache directory root.
        cv_info_file: cross-validation info file, parsed by ``read_cvinfo``.
        models_run: network types to build DBs for (and to train when ``dorun``).
        view: 0-based view index.
        skip_db: when True, neither split files nor DBs are written.
        create_splits: when False, the split JSON is not written; DB creation
            then uses whatever file already exists at that path.
        dorun: when True, call ``rapt.run_trainining`` for each split and model.
        run_type: forwarded to ``rapt.run_trainining``.
    """
    cv_info, in_info, label_info = read_cvinfo(lbl_file, cdir, cv_info_file,
                                               view)
    with h5py.File(lbl_file, 'r') as lbl:
        proj_name = apt.read_string(lbl['projname'])

    cvifileshort = os.path.splitext(os.path.basename(cv_info_file))[0]
    n_splits = max(cv_info) + 1
    print("{} splits, {} rows in cvi, {} rows in lbl, projname {}".format(
        n_splits, len(cv_info), len(label_info), proj_name))

    for sndx in range(n_splits):
        val_info = [l for ndx, l in enumerate(in_info) if cv_info[ndx] == sndx]
        trn_info = list(set(label_info) - set(val_info))
        cur_split = [trn_info, val_info]
        exp_name = '{:s}__split{}'.format(cvifileshort, sndx)
        split_file = os.path.join(cdir, proj_name, exp_name) + '.json'
        if not skip_db and create_splits:
            # Never overwrite an existing split definition.
            assert not os.path.exists(split_file)
            with open(split_file, 'w') as f:
                json.dump(cur_split, f)

        # create the dbs
        if not skip_db:
            for train_type in models_run:
                conf = apt.create_conf(lbl_file, view, exp_name, cdir,
                                       train_type)
                conf.splitType = 'predefined'
                if train_type == 'deeplabcut':
                    apt.create_deepcut_db(conf, split=True,
                                          split_file=split_file,
                                          use_cache=True)
                elif train_type == 'leap':
                    apt.create_leap_db(conf, split=True,
                                       split_file=split_file, use_cache=True)
                else:
                    apt.create_tfrecord(conf, split=True,
                                        split_file=split_file, use_cache=True)

        if dorun:
            for train_type in models_run:
                # Previously passed the undefined name `elblbubxp_name`
                # (NameError); the experiment for this split is `exp_name`.
                rapt.run_trainining(exp_name, train_type, view, run_type)
def read_cvinfo(lbl_file, cdir, cv_info_file, view=0):
    """Read cross-validation assignments and compare them with the label file.

    Args:
        lbl_file: APT label file used to build a conf for ``rapt.get_label_info``.
        cdir: cache directory passed to ``apt.create_conf``.
        cv_info_file: HDF5 file with 'cvi', 'frame', 'movieidx' and optionally
            'target' datasets; only the first column of each is used.
        view: 0-based view index.

    Returns:
        (cv_info, in_info, label_info): ``cv_info`` holds the split index of
        each row, ``in_info`` holds ``(movieidx, frame, target)`` tuples (target
        is 0 when the file has no 'target' dataset), and ``label_info`` is what
        ``rapt.get_label_info`` returns for the label file.
    """
    conf = apt.create_conf(lbl_file, view, 'cv_dummy', cdir,
                           'mdn')  # net type irrelevant
    label_info = rapt.get_label_info(conf)

    def _first_col(dset):
        # `Dataset.value` was removed in h5py 3.0; `dset[()]` reads the full
        # dataset in both h5py 2.x and 3.x.
        return apt.to_py(dset[()][:, 0].astype('int'))

    # Context manager so the cv-info file is closed after reading.
    with h5py.File(cv_info_file, 'r') as cvi:
        cv_info = _first_col(cvi['cvi'])
        fr_info = _first_col(cvi['frame'])
        m_info = _first_col(cvi['movieidx'])
        if 'target' in cvi.keys():
            t_info = _first_col(cvi['target'])
            in_info = [(a, b, c) for a, b, c in zip(m_info, fr_info, t_info)]
        else:
            in_info = [(a, b, 0) for a, b in zip(m_info, fr_info)]

    # Only report mismatches; row order between the two sources is not checked.
    diff1 = list(set(label_info) - set(in_info))
    diff2 = list(set(in_info) - set(label_info))
    print('Number of labels that exists in label file but not in mat file:{}'.
          format(len(diff1)))
    print('Number of labels that exists in mat file but not in label file:{}'.
          format(len(diff2)))
    return cv_info, in_info, label_info
def test_crop():
    """Compute automatic crop locations for each labeled movie of both views.

    For each of the two views, the project's local movies are listed. The fly
    number is parsed from the third-from-last path component. For flies present
    in the body-axis lookup CSV, ``ts.get_crop_locs`` is called with the movie's
    frame size. Fly numbers and crop locations are saved per view with
    hdf5storage.
    """
    import trackStephenHead_KB as ts
    import APT_interface as apt
    import multiResData
    import cv2
    from cvc import cvc
    import os
    import re
    import hdf5storage
    import h5py

    lbl_file = '/groups/branson/bransonlab/mayank/stephen_copy/apt_cache/sh_trn4879_gtcomplete.lbl'
    bodylblfile = '/groups/branson/bransonlab/mayank/stephen_copy/fly2BodyAxis_lookupTable_Ben.csv'

    # fly number -> second field of the 'flynum,value' CSV line.
    bodydict = {}
    with open(bodylblfile, 'r') as f:
        for l in f:
            lparts = l.split(',')
            if len(lparts) != 2:
                print("Error splitting body label file line %s into two parts" % l)
                exit(0)
            bodydict[int(lparts[0])] = lparts[1].strip()

    flynums = [[], []]
    crop_locs = [[], []]
    for view in range(2):
        conf = apt.create_conf(lbl_file, view, 'aa',
                               cache_dir='/groups/branson/home/kabram/temp')
        movs = multiResData.find_local_dirs(conf)[0]
        for mov in movs:
            dir_parts = os.path.normpath(mov).split(os.sep)
            aa = re.search(r'fly_*(\d+)', dir_parts[-3])
            flynum = int(aa.groups()[0])
            # `dict.has_key` does not exist on Python 3; `in` works on both.
            if flynum in bodydict:
                cap = cv2.VideoCapture(mov)
                height = int(cap.get(cvc.FRAME_HEIGHT))
                width = int(cap.get(cvc.FRAME_WIDTH))
                cap.release()
                crop_locs[view].append(
                    ts.get_crop_locs(bodydict[flynum], view, height,
                                     width))  # return x first
                flynums[view].append(flynum)

    hdf5storage.savemat(
        '/groups/branson/bransonlab/mayank/stephen_copy/auto_crop_locs_trn4879',
        {
            'flynum': flynums,
            'crop_locs': crop_locs
        })
def save_cv_results(lbl_file, cachedir, view, exp_name, net, model_file_short,
                    out_dir, data_type, kwout, mdn_hm_floor=0.1, db_file=None):
    """Classify a DB with one trained model and pickle the results.

    Args:
        lbl_file, cachedir, view, exp_name, net: passed to ``apt.create_conf``.
        model_file_short: model file name relative to ``conf.cachedir``.
        out_dir: directory for the output pickle.
        data_type: dataset tag; only 'bub' has an openpose affinity graph.
        kwout: suffix used in the output file name.
        mdn_hm_floor: forwarded as ``hm_floor`` to ``classify_db_all``.
        db_file: DB to classify; defaults to ``<conf.cachedir>/val_TF.tfrecords``.

    The pickle holds the list returned by ``apt_expts.classify_db_all`` with
    ``conf`` appended, written to
    ``<out_dir>/<exp_name>__vw<view>__<net>__<kwout>.p``.

    Raises:
        AssertionError: for openpose with a ``data_type`` other than 'bub'.
    """
    conf_pvlist = None
    if net == 'openpose':
        if data_type == 'bub':
            conf_pvlist = ['op_affinity_graph', op_af_graph_bub_noslash]
        else:
            # Explicit raise: `assert False` is stripped under `python -O`.
            raise AssertionError("define aff graph")

    # Heatmaps are requested only for MDN models.
    return_hmaps = (net == 'mdn')
    conf = apt.create_conf(lbl_file, view, exp_name, cachedir, net,
                           conf_params=conf_pvlist)
    if db_file is None:
        db_file = os.path.join(conf.cachedir, 'val_TF.tfrecords')
    model_file = os.path.join(conf.cachedir, model_file_short)
    res = apt_expts.classify_db_all(conf, db_file, [model_file], net,
                                    return_hm=return_hmaps, hm_dec=1,
                                    hm_floor=mdn_hm_floor, hm_nclustermax=1)
    res.append(conf)

    out_file = "{}__vw{}__{}__{}.p".format(exp_name, view, net, kwout)
    out_file = os.path.join(out_dir, out_file)
    # pickle writes bytes; a text-mode handle fails on Python 3.
    with open(out_file, 'wb') as f:
        pickle.dump(res, f)
    print("saved {}".format(out_file))
def train():
    """Train a PoseUMDN_resnet network separately for each of the two views.

    The run is named ``stephen_<datestr>``. For each view, a conf is built from
    the module-level ``lbl_file``, ``cache_dir`` and ``model_type`` and adjusted
    by ``update_conf``, and the TF records are (re)created with the cache
    enabled. The graph is then reset before a fresh network is trained.
    """
    import PoseUNet_resnet as PoseURes
    import tensorflow as tf

    run_name = 'stephen_{}'.format(PoseTools.datestr())
    for cur_view in range(2):
        view_conf = apt.create_conf(lbl_file, view=cur_view, name=run_name,
                                    cache_dir=cache_dir, net_type=model_type)
        update_conf(view_conf)
        apt.create_tfrecord(view_conf, False, use_cache=True)

        tf.reset_default_graph()
        net = PoseURes.PoseUMDN_resnet(view_conf, name='deepnet')
        net.train_data_name = 'traindata'
        net.train_umdn()
def predsingle(lbl_file, cachedir, view, exp_name, net, model_file_short,
               data_type):
    """Run raw prediction for the first example of the validation DB.

    Args:
        lbl_file, cachedir, view, exp_name, net: passed to ``apt.create_conf``.
        model_file_short: model file name relative to ``conf.cachedir``.
        data_type: dataset tag; only 'bub' has an openpose affinity graph.

    Returns:
        (predmaps, im, locs, info): raw network output for the single example
        read from ``<conf.cachedir>/val_TF.tfrecords`` (batch size 1), plus the
        image, labeled locations and info returned by the reader.

    Raises:
        AssertionError: for openpose with a ``data_type`` other than 'bub'.
    """
    conf_pvlist = None
    if net == 'openpose':
        if data_type == 'bub':
            conf_pvlist = ['op_affinity_graph', op_af_graph_bub_noslash]
        else:
            # Explicit raise: `assert False` is stripped under `python -O`.
            raise AssertionError("define aff graph")

    conf = apt.create_conf(lbl_file, view, exp_name, cachedir, net,
                           conf_params=conf_pvlist)
    db_file = os.path.join(conf.cachedir, 'val_TF.tfrecords')
    model_file = os.path.join(conf.cachedir, model_file_short)

    tf_iterator = multiResData.tf_reader(conf, db_file, False)
    tf_iterator.batch_size = 1
    read_fn = tf_iterator.next
    pred_fn, close_fn, _ = op.get_pred_fn(conf, model_file, name=None,
                                          rawpred=True)
    im, locs, info, _ = read_fn()
    print("im.shape is {}".format(im.shape))
    predmaps = pred_fn(im)
    close_fn()
    return predmaps, im, locs, info
def _fly_number_from_movie(movie):
    """Return the fly number parsed from ``movie``'s third-from-last path part, or None."""
    dir_parts = os.path.normpath(movie).split(os.sep)
    match = re.search(r'fly_*(\d+)', dir_parts[-3])
    if match is None:
        return None
    return int(match.groups()[0])


def main(argv):
    """Detect landmarks in paired side/front movies and submit 3D tracking jobs.

    Detection (``-only_detect``) runs the network on every movie. Per-movie
    results are written to ``<outdir>/<expname><_side|_front>.mat``. Tracking
    (``-only_track``) writes a shell script per side/front movie pair that
    runs the compiled MATLAB tracker, and submits it with ``bsub`` over ssh.
    With neither flag, both stages run.

    Args:
        argv: command-line arguments (without the program name).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-s", dest="sfilename",
                        help="text file with list of side view videos",
                        required=True)
    parser.add_argument(
        "-f", dest="ffilename",
        help="text file with list of front view videos. The list of side view videos and front view videos should match up",
        required=True)
    parser.add_argument(
        "-d", dest="dltfilename",
        help="text file with list of DLTs, one per fly as 'flynum,/path/to/dltfile'",
        required=True)
    parser.add_argument(
        "-body_lbl", dest="bodylabelfilename",
        help="text file with list of body-label files, one per fly as 'flynum,/path/to/body_label.lbl'",
        default=bodylblfile)
    parser.add_argument("-net", dest="net_name",
                        help="Name of the net to use for tracking",
                        default=default_net_name)
    parser.add_argument(
        "-o", dest="outdir",
        help="temporary output directory to store intermediate computations",
        required=True)
    parser.add_argument("-r", dest="redo",
                        help="if specified will recompute everything",
                        action="store_true")
    parser.add_argument("-rt", dest="redo_tracking",
                        help="if specified will only recompute tracking",
                        action="store_true")
    parser.add_argument("-gpu", dest='gpunum', type=int,
                        help="GPU to use [optional]")
    parser.add_argument("-makemovie", dest='makemovie',
                        help="if specified will make results movie",
                        action="store_true")
    parser.add_argument(
        "-trackerpath", dest='trackerpath',
        help="Absolute path to the compiled MATLAB tracker script run_compute3Dfrom2D.sh",
        default=defaulttrackerpath)
    parser.add_argument("-mcrpath", dest='mcrpath',
                        help="Absolute path to MCR", default=defaultmcrpath)
    parser.add_argument(
        "-ncores", dest="ncores",
        help="Number of cores to assign to each MATLAB tracker job",
        type=int, default=1)
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "-only_detect", dest='detect', action="store_true",
        help="Do only the detection part of tracking which requires GPU")
    group.add_argument(
        "-only_track", dest='track', action="store_true",
        help="Do only the tracking part of the tracking which requires MATLAB")
    args = parser.parse_args(argv)

    # store_true flags are never None, so no defaulting is needed for -r/-rt.
    if not args.detect and not args.track:
        args.detect = True
        args.track = True

    args.outdir = os.path.abspath(args.outdir)

    with open(args.sfilename, "r") as text_file:
        smovies = [x.rstrip() for x in text_file.readlines()]
    with open(args.ffilename, "r") as text_file:
        fmovies = [x.rstrip() for x in text_file.readlines()]
    print(smovies)
    print(fmovies)
    print(len(smovies))
    print(len(fmovies))

    if len(smovies) != len(fmovies):
        print("Side and front movies must match")
        exit(0)

    if args.track:
        # read in dltfile: fly number -> DLT file path
        dltdict = {}
        with open(args.dltfilename, 'r') as f:
            for l in f:
                lparts = l.split(',')
                if len(lparts) != 2:
                    print("Error splitting dlt file line %s into two parts" % l)
                    exit(0)
                dltdict[float(lparts[0])] = lparts[1].strip()

        # compiled matlab command
        matscript = args.trackerpath + " " + args.mcrpath

    if args.detect:
        import numpy as np
        import tensorflow as tf
        from scipy import io
        from cvc import cvc
        import localSetup
        import PoseTools
        import multiResData
        import cv2
        import PoseUNet

        for ff in smovies + fmovies:
            if not os.path.isfile(ff):
                print("Movie %s not found" % (ff))
                exit(0)
        if args.gpunum is not None:
            # Honour -gpu; previously '0' was hard-coded, ignoring the flag.
            os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpunum)

        # fly number -> body-label entry, used to compute crop locations
        bodydict = {}
        with open(args.bodylabelfilename, 'r') as f:
            for l in f:
                lparts = l.split(',')
                if len(lparts) != 2:
                    print(
                        "Error splitting body label file line %s into two parts"
                        % l)
                    exit(0)
                bodydict[int(lparts[0])] = lparts[1].strip()

    for view in range(2):  # 0 for side ('_side') and 1 for front ('_front')
        if view == 0:
            extrastr = '_side'
            valmovies = smovies
        else:
            extrastr = '_front'
            valmovies = fmovies

        if args.detect:
            tf.reset_default_graph()
            conf = apt.create_conf(lbl_file, view=view, name=name,
                                   cache_dir=cache_dir, net_type=model_type)
            update_conf(conf)
            # Loading the network occasionally fails with ValueError; retry.
            for try_num in range(4):
                try:
                    tf.reset_default_graph()
                    pred_fn, close_fn, model_file = apt.get_pred_fn(
                        model_type=model_type, conf=conf)
                    break
                except ValueError:
                    print('Loading the net failed, retrying')
                    if try_num == 3:
                        raise ValueError(
                            'Couldnt load the network after 4 tries')

        for ndx in range(len(valmovies)):
            oname = re.sub('!', '__', getexpname(valmovies[ndx]))
            pname = os.path.join(args.outdir, oname + extrastr)
            print(oname)

            # detect
            if args.detect and os.path.isfile(valmovies[ndx]) and \
                    (args.redo or not os.path.isfile(pname + '.mat')):
                cap = cv2.VideoCapture(valmovies[ndx])
                height = int(cap.get(cvc.FRAME_HEIGHT))
                width = int(cap.get(cvc.FRAME_WIDTH))
                cap.release()
                flynum = _fly_number_from_movie(valmovies[ndx])
                if flynum is None:
                    print('Could not find the fly number from movie name')
                    # Report the movie actually being processed (was smovies).
                    print('{} isnt in standard format'.format(valmovies[ndx]))
                    continue
                crop_loc_all = get_crop_locs(bodydict[flynum], view, height,
                                             width)  # return x first
                try:
                    predLocs, predScores, pred_ulocs, pred_conf = classify_movie(
                        valmovies[ndx], pred_fn, conf, crop_loc_all)
                except KeyError:
                    continue

                hdf5storage.savemat(pname + '.mat', {
                    'locs': predLocs,
                    'scores': predScores,
                    'expname': valmovies[ndx],
                    'crop_loc': crop_loc_all,
                    'model_file': model_file,
                    'ulocs': pred_ulocs,
                    'pred_conf': pred_conf
                }, appendmat=False, truncate_existing=True,
                    gzip_compression_level=0)
                del predScores, predLocs
                print('Detecting:%s' % oname)

            # track (once per movie pair, after both views were detected)
            if args.track and view == 1:
                oname_side = re.sub('!', '__', getexpname(smovies[ndx]))
                oname_front = re.sub('!', '__', getexpname(fmovies[ndx]))
                pname_side = os.path.join(args.outdir, oname_side + '_side.mat')
                pname_front = os.path.join(args.outdir,
                                           oname_front + '_front.mat')
                # 3d trajectories
                basename_front, _ = os.path.splitext(fmovies[ndx])
                basename_side, _ = os.path.splitext(smovies[ndx])
                savefile = basename_side + '_3Dres.mat'
                trkfile_front = basename_front + '.trk'
                trkfile_side = basename_side + '.trk'

                redo_tracking = args.redo or args.redo_tracking
                if os.path.isfile(savefile) and os.path.isfile(trkfile_front) and \
                        os.path.isfile(trkfile_side) and not redo_tracking:
                    print("%s, %s, and %s exist, skipping tracking" %
                          (savefile, trkfile_front, trkfile_side))
                    continue

                flynum = _fly_number_from_movie(smovies[ndx])
                if flynum is None:
                    print('Could not find the fly number from movie name')
                    print('{} isnt in standard format'.format(smovies[ndx]))
                    continue

                kinematfile = os.path.abspath(dltdict[flynum])
                jobid = oname_side
                scriptfile = os.path.join(args.outdir, jobid + '_track.sh')
                logfile = os.path.join(args.outdir, jobid + '_track.log')
                errfile = os.path.join(args.outdir, jobid + '_track.err')

                # make script to be submitted to the cluster
                with open(scriptfile, 'w') as scriptf:
                    scriptf.writelines([
                        'if [ -d %s ]\n' % args.outdir,
                        ' then export MCR_CACHE_ROOT=%s/mcrcache%s\n' %
                        (args.outdir, jobid),
                        'fi\n',
                        '%s "%s" "%s" "%s" "%s" "%s" "%s"\n' %
                        (matscript, savefile, pname_front, pname_side,
                         kinematfile, trkfile_front, trkfile_side),
                        'chmod g+w {}\n'.format(savefile),
                        'chmod g+w {}\n'.format(trkfile_front),
                        'chmod g+w {}\n'.format(trkfile_side),
                    ])
                os.chmod(
                    scriptfile,
                    stat.S_IRUSR | stat.S_IRGRP | stat.S_IWUSR | stat.S_IWGRP
                    | stat.S_IXUSR | stat.S_IXGRP)

                cmd = "ssh 10.36.11.34 'source /etc/profile; bsub -n %d -J %s -oo '%s' -eo '%s' -cwd . '\"%s\"''" % (
                    args.ncores, jobid, logfile, errfile, scriptfile)
                print(cmd)
                call(cmd, shell=True)
import APT_interface as apt
apt.main(args)

##
model_file = '/home/mayank/Dropbox (HHMI)/temp/alice/leap/final_model.h5'
lbl_file = '/home/mayank/work/poseTF/data/leap/leap_data.lbl'
cache_dir = '/home/mayank/work/poseTF/cache/leap_db'
import sys
import socket
import numpy as np
import os
import APT_interface as apt

view = 0
# Keyword arguments: the positional call previously passed 'leap' as the
# cache dir and the cache path as the net type (swapped relative to every
# other create_conf call in this file).
conf = apt.create_conf(lbl_file, view, 'leap_db', cache_dir=cache_dir,
                       net_type='leap')
apt.create_leap_db(conf, False)
data_path = os.path.join(cache_dir, 'leap_train.h5')
cmd = 'python leap/training_MK.py {}'.format(data_path)
print('RUN: {}'.format(cmd))

##
import APT_interface as apt
import os
import h5py
import logging
reload(apt)
lbl_file = '/home/mayank/work/poseTF/data/stephen/sh_cacheddata_20180717T095200.lbl'
def run_training(lbl_file, cdir, exp_name, data_type, train_type, view,
                 run_type, **kwargs):
    """Build an APT_interface training command and print, submit or check it.

    Args:
        lbl_file, cdir, exp_name: label file, cache dir and experiment name.
        data_type: dataset tag; used in the job name and for the openpose
            affinity graph (only 'bub' is defined).
        train_type: network type passed as ``-type``.
        view: 0-based view; passed to APT_interface as 1-based ``-view``.
        run_type: 'dry' prints the command, 'submit' prints and calls
            ``run_jobs``, 'status' calls ``check_train_status``.
        **kwargs: conf parameters that override ``rapt.common_conf`` (and
            the derived ``save_step``). Per-dataset settings such as
            batch_size or rescale go here.

    Raises:
        AssertionError: for openpose with a ``data_type`` other than 'bub'.
    """
    common_cmd = 'APT_interface.py {} -name {} -cache {}'.format(
        lbl_file, exp_name, cdir)
    end_cmd = 'train -skip_db -use_cache'

    cmd_opts = {}
    cmd_opts['type'] = train_type
    cmd_opts['view'] = view + 1

    conf_opts = rapt.common_conf.copy()
    # Integer division: '/' gives a float step count on Python 3.
    conf_opts['save_step'] = conf_opts['dl_steps'] // 10
    conf_opts.update(kwargs)
    if train_type == 'openpose':
        if data_type == 'bub':
            conf_opts['op_affinity_graph'] = op_af_graph_bub
        else:
            # Explicit raise: `assert False` is stripped under `python -O`.
            raise AssertionError("define aff graph")

    if len(conf_opts) > 0:
        conf_str = ' -conf_params'
        for k in conf_opts.keys():
            conf_str = '{} {} {} '.format(conf_str, k, conf_opts[k])
    else:
        conf_str = ''

    opt_str = ''
    for k in cmd_opts.keys():
        opt_str = '{} -{} {} '.format(opt_str, k, cmd_opts[k])

    cur_cmd = common_cmd + conf_str + opt_str + end_cmd
    cmd_name = '{}_view{}_{}_{}'.format(data_type, view, exp_name, train_type)
    if run_type == 'dry':
        print(cmd_name)
        print(cur_cmd)
        print('')
    elif run_type == 'submit':
        print(cmd_name)
        print(cur_cmd)
        print('')
        run_jobs(cmd_name, cur_cmd)
    elif run_type == 'status':
        conf = apt.create_conf(lbl_file, view, exp_name, cdir, train_type)
        check_train_status(cmd_name, conf.cachedir)
import APT_interface as apt
import os

# Alice's dataset: build split TF records for each view under /home/kabram/temp.
name = 'alice'
val_ratio = 0.1
lbl_file = '/home/kabram/Dropbox (HHMI)/temp/multitarget_bubble_expandedbehavior_20180425_FxdErrs_OptoParams20181126_dlstripped.lbl'
nviews = 1
for view in range(nviews):
    # Use the loop's view; a hard-coded 0 built every view's DB from view 0.
    conf = apt.create_conf(lbl_file, view, 'tfds', '/home/kabram/temp', 'mdn')
    conf.cachedir = '/home/kabram/temp/tfds_{}_view{}'.format(name, view)
    conf.valratio = val_ratio
    os.makedirs(conf.cachedir, exist_ok=True)
    apt.create_tfrecord(conf, split=True, split_file=None, use_cache=True,
                        on_gt=False)
# Override batch_size and maxckpt in common_conf for these runs.
common_conf['batch_size'] = 8
common_conf['maxckpt'] = 20
cache_dir = '/nrs/branson/mayank/apt_cache'
train_name = 'deepnet'
# Requires that no GT label file is set for this path.
assert gt_lbl is None
all_view = []
for view in range(nviews):
    out_exp = {}
    for tndx in range(len(all_models)):
        train_type = all_models[tndx]
        out_split = None
        for split in range(n_splits):
            exp_name = 'cv_split_{}'.format(split)
            # NOTE(review): mdn_conf is not used in the visible part of this fragment.
            mdn_conf = apt.create_conf(lbl_file, view, exp_name, cache_dir, 'mdn')
            conf = apt.create_conf(lbl_file, view, exp_name, cache_dir, train_type)
            if op_af_graph is not None:
                # Remove backslashes before parsing the affinity-graph literal.
                conf.op_affinity_graph = ast.literal_eval(
                    op_af_graph.replace('\\', ''))
            # Checkpoint files '<train_name>-<digits>...' in conf.cachedir,
            # ordered oldest-first by modification time; only '.index' or
            # extension-less entries are kept.
            files = glob.glob(
                os.path.join(conf.cachedir, "{}-[0-9]*").format(train_name))
            files.sort(key=os.path.getmtime)
            files = [
                f for f in files if os.path.splitext(f)[1] in ['.index', '']
            ]
            # Step number = digits after the first '-' in the FULL path.
            # NOTE(review): a '-' inside conf.cachedir would be matched first.
            aa = [int(re.search('-(\d*)', f).groups(0)[0]) for f in files]
            # Consecutive step differences; a negative value means the step
            # count went down between files (presumably a restarted run).
            aa = [b - a for a, b in zip(aa[:-1], aa[1:])]
            # NOTE(review): the body of this `if` is not part of this chunk.
            if any([a < 0 for a in aa]):
# Wrap the externally computed LEAP predictions/labels in a 6-entry result
# record; the four empty lists are placeholders (presumably matching the
# layout rae.plot_hist expects — TODO confirm).
out_leap = [[preds,labels,[],[],[],[]]]
# Euclidean distance over axis 1 of (labels - preds).
# NOTE(review): assumes axis 1 holds the x/y coordinates — confirm layout.
dd_leap = np.sqrt(np.sum((labels-preds)**2,1))
dd_leap = dd_leap.T
cache_dir = '/nrs/branson/mayank/apt_cache'
exp_name = 'apt_expt'
train_name = 'deepnet'
# GT TF-record file for `view` (view comes from an earlier cell).
gt_file = os.path.join(cache_dir, rae.proj_name, 'gtdata', 'gtdata_view{}{}.tfrecords'.format(view, rae.gt_name))
H = multiResData.read_and_decode_without_session(gt_file, 32)
# Example image (first channel) and its labeled locations for plotting.
ex_im = np.array(H[0][0])[:, :, 0]
ex_loc = np.array(H[1][0])
# Results of our LEAP re-implementation; last entry holds (preds, labels, ...).
our_res = pt.pickle_load('/nrs/branson/mayank/apt_cache/leap_dset/leap/view_0/apt_expt/deepnet_results.p')
our_preds = our_res[0][-1][0]
our_labels = our_res[0][-1][1]
# Classify the GT DB with the original LEAP weights (view 0 paths are hard-coded).
conf = apt.create_conf(rae.lbl_file,0,'apt_expt',cache_dir,'leap')
orig_leap_models = ['/nrs/branson/mayank/apt_cache/leap_dset/leap/view_0/apt_expt/weights-045.h5',]
orig_leap = apt_expts.classify_db_all(conf,gt_file,orig_leap_models,'leap',name=train_name)
out_dict = {'leap':out_leap,'our leap':our_res[0],'leap_orig':orig_leap}
rae.plot_hist([out_dict,ex_im,ex_loc])

## mdn with and without unet
import run_apt_expts as rae
rae.setup('alice')
rae.run_mdn_no_unet()

##
import run_apt_expts as rae
rae.setup('alice')
## Setup
import APT_interface as apt
import RNN_postprocess
import tensorflow as tf
import os
import easydict

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
tf.reset_default_graph()
exp_name = 'postprocess'
view = 0
mdn_name = 'deepnet'
lbl_file = '/groups/branson/bransonlab/apt/experiments/data/sh_trn4992_gtcomplete_cacheddata_updatedAndPpdbManuallyCopied20190402_dlstripped.lbl'

# MDN conf for the post-processing experiment, with training overrides.
conf = apt.create_conf(lbl_file, view, exp_name, '/nrs/branson/mayank/apt_cache', 'mdn')
conf.n_steps = 2
conf.rrange = 30
conf.trange = 10
conf.mdn_use_unet_loss = True
conf.dl_steps = 40000
conf.decay_steps = 20000
conf.save_step = 5000
conf.batch_size = 8
conf.normalize_img_mean = False
conf.maxckpt = 20

## Train MDN
# Instantiate: assigning the EasyDict class itself would set attributes on
# the shared class rather than on an args object.
args = easydict.EasyDict()
args.skip_db = False
url_lib.urlretrieve(gt_file_url, gt_file)
res_file_url = 'https://www.dropbox.com/s/cr702321rvv3htl/alice_view0_time.mat?dl=1'
res_file = os.path.join(tdir,'alice_view0_time.mat')
url_lib.urlretrieve(res_file_url,res_file)
# '{{}}' leaves a '{}' placeholder that is filled with the net type per run.
cmd = '-cache {} -name {} -conf_params batch_size {} dl_steps {} op_affinity_graph {} -type {{}} {} train -use_cache '.format(tdir, exp_name, bsz, dl_steps,op_af_graph, lbl_file)

##
import h5py
R = h5py.File(res_file,'r')
for net in net_types:
    apt.main(cmd.format(net).split())
    conf = apt.create_conf(lbl_file, 0, exp_name, tdir, net)
    if op_af_graph is not None:
        conf.op_affinity_graph = ast.literal_eval(op_af_graph.replace('\\', ''))
    # Checkpoints 'deepnet-<step>[.index]', oldest first by modification time.
    files = glob.glob(os.path.join(conf.cachedir, "{}-[0-9]*").format('deepnet'))
    files.sort(key=os.path.getmtime)
    files = [f for f in files if os.path.splitext(f)[1] in ['.index', '']]
    # Parse the step from the file NAME so a '-' in the directory path
    # cannot be matched first.
    aa = [int(re.search(r'-(\d*)', os.path.basename(f)).groups(0)[0]) for f in files]
    aa = [b - a for a, b in zip(aa[:-1], aa[1:])]
    if any([a < 0 for a in aa]):
        # Step count decreased => a new run started; keep only files after
        # the LAST such reset (int() on the full index array failed when
        # there was more than one reset).
        bb = int(np.where(np.array(aa) < 0)[0][-1]) + 1
        files = files[bb:]
    files = files[-1:]
    # n_max = 10
    # if len(files)> n_max:
def _load_or_classify_gt(conf, curm, db_file, m, out_file):
    """Return (pred, label, gt_list) for model ``m`` on ``db_file``, cached in ``out_file``.mat."""
    # unet/deeplabcut checkpoints are timestamped through their '.index' file.
    if curm == 'unet' or curm == 'deeplabcut':
        mm = m + '.index'
    else:
        mm = m
    cache_file = out_file + '.mat'
    if os.path.exists(cache_file) and \
            os.path.getmtime(cache_file) > os.path.getmtime(mm):
        H = sio.loadmat(out_file)
        # Locations are stored 1-based (see save below).
        pred = H['pred_locs'] - 1
        label = H['labeled_locs'] - 1
        # 'list' is stored exactly as returned by classify_db (no +1 on
        # save), so it must not be shifted on load; the previous '- 1'
        # made cached results differ from freshly computed ones.
        gt_list = H['list']
        return pred, label, gt_list

    tf_iterator = multiResData.tf_reader(conf, db_file, False)
    tf_iterator.batch_size = 1
    read_fn = tf_iterator.next
    pred_fn, close_fn, _ = apt.get_pred_fn(curm, conf, m)
    try:
        pred, label, gt_list = apt.classify_db(conf, read_fn, pred_fn,
                                               tf_iterator.N)
    finally:
        close_fn()
    sio.savemat(
        out_file, {
            'pred_locs': pred + 1,
            'labeled_locs': np.array(label) + 1,
            'list': gt_list
        })
    return pred, label, gt_list


def compute_peformance(args):
    """Evaluate every saved model of every net on GT data and pickle the results.

    GT TF-record files ``<out_dir>/<name>/<gt_name>_view<k>.tfrecords`` are
    (re)built unless they exist and ``args.skip_gt_db`` is set. Each model
    file found by ``get_model_files`` is then evaluated, reusing the cached
    .mat predictions when they are newer than the model. Without a split type,
    models are evaluated on the GT DB. With a split type, each cv split's
    models are evaluated on that split's ``val_TF.tfrecords``.

    The pickle ``<out_dir>/<name>/<whose>/<gt_name>_results.p`` maps each
    net to one list per view of
    ``[pred, label, gt_list, model_file, out_file, view, split, ts]`` entries.
    """
    with h5py.File(args.lbl_file, 'r') as H:
        nviews = int(apt.read_entry(H['cfg']['NumViews']))
    dir_name = args.whose
    all_nets = methods if len(args.nets) == 0 else args.nets
    gt_base = os.path.join(out_dir, args.name, args.gt_name)

    all_preds = {}
    for view in range(nviews):
        db_file = gt_base + '_view{}.tfrecords'.format(view)
        conf = apt.create_conf(args.lbl_file, view, name='a',
                               net_type=all_nets[0],
                               cache_dir=os.path.join(out_dir, args.name,
                                                      dir_name))
        conf.labelfile = args.gt_lbl
        if not (os.path.exists(db_file) and args.skip_gt_db):
            print('Creating GT DB file {}'.format(db_file))
            apt.create_tfrecord(conf, split=False, on_gt=True,
                                db_files=(db_file, ))

    for curm in all_nets:
        all_preds[curm] = []
        for view in range(nviews):
            cur_out = []
            db_file = gt_base + '_view{}.tfrecords'.format(view)
            if args.split_type is None:
                cachedir = os.path.join(out_dir, args.name, dir_name,
                                        '{}_view_{}'.format(curm, view),
                                        'full')
                conf = apt.create_conf(args.lbl_file, view, name='a',
                                       net_type=curm, cache_dir=cachedir)
                model_files, ts = get_model_files(conf, cachedir, curm)
                for mndx, m in enumerate(model_files):
                    out_file = m + '_' + args.gt_name
                    pred, label, gt_list = _load_or_classify_gt(
                        conf, curm, db_file, m, out_file)
                    cur_out.append([
                        pred, label, gt_list, m, out_file, view, 0, ts[mndx]
                    ])
            else:
                for cur_split in range(nsplits):
                    # NOTE(review): unlike the 'full' branch this path omits
                    # `dir_name` — confirm this is the intended location.
                    cachedir = os.path.join(out_dir, args.name,
                                            '{}_view_{}'.format(curm, view),
                                            'cv_{}'.format(cur_split))
                    conf = apt.create_conf(args.lbl_file, view, name='a',
                                           net_type=curm, cache_dir=cachedir)
                    model_files, ts = get_model_files(conf, cachedir, curm)
                    split_db_file = os.path.join(cachedir, 'val_TF.tfrecords')
                    for mndx, m in enumerate(model_files):
                        out_file = m + '.gt_data'
                        pred, label, gt_list = _load_or_classify_gt(
                            conf, curm, split_db_file, m, out_file)
                        cur_out.append([
                            pred, label, gt_list, m, out_file, view,
                            cur_split, ts[mndx]
                        ])
            all_preds[curm].append(cur_out)

    # pickle writes bytes; a text-mode handle fails on Python 3.
    with open(
            os.path.join(out_dir, args.name, dir_name,
                         args.gt_name + '_results.p'), 'wb') as f:
        pickle.dump(all_preds, f)
def _create_net_db(curm, conf, split_file=None):
    """Build the DB for network type ``curm``; split by ``split_file`` when given."""
    db_args = (conf, False) if split_file is None else (conf, True, split_file)
    if curm == 'unet' or curm == 'openpose':
        apt.create_tfrecord(*db_args)
    elif curm == 'leap':
        apt.create_leap_db(*db_args)
    elif curm == 'deeplabcut':
        apt.create_deepcut_db(*db_args)
        create_deepcut_cfg(conf)
    else:
        raise ValueError('Undefined net type: {}'.format(curm))


def create_db(args):
    """Create training DBs for every net and view under ``<out_dir>/<name>/common``.

    * ``args.split_type is None``: one un-split DB per net/view in ``.../full``.
    * ``args.split_type`` starting with 'prog': for each threshold in
      ``prog_thresholds``, the first ``round(N / thr)`` entries of
      ``get_increasing_splits`` are used for training and the rest for
      validation.
    * any other split type: ``nsplits`` cross-validation splits, created with
      ``apt.create_cv_split_files`` when ``args.do_split`` (else existing
      ``cv_split_fold_<k>.json`` files are used), one DB per split.

    With ``args.only_check`` the DBs are not created, only checked with
    ``check_db``. Finally the ``common`` tree is copied as symlinks to
    ``theirs``, ``ours`` and ``ours_default``.
    """
    import subprocess

    with h5py.File(args.lbl_file, 'r') as H:
        nviews = int(apt.read_entry(H['cfg']['NumViews']))
    all_nets = args.nets
    common_dir = os.path.join(out_dir, args.name, 'common')

    all_split_files = []
    for view in range(nviews):
        if args.split_type is not None and not args.split_type.startswith(
                'prog'):
            if not os.path.exists(common_dir):
                os.mkdir(common_dir)
            cachedir = os.path.join(common_dir, 'splits_{}'.format(view))
            if not os.path.exists(cachedir):
                os.mkdir(cachedir)
            conf = apt.create_conf(args.lbl_file, view, args.name,
                                   cache_dir=cachedir)
            conf.splitType = args.split_type
            print("Split type is {}".format(conf.splitType))
            if args.do_split:
                train_info, val_info, split_files = apt.create_cv_split_files(
                    conf, nsplits)
            else:
                split_files = [
                    os.path.join(conf.cachedir,
                                 'cv_split_fold_{}.json'.format(ndx))
                    for ndx in range(nsplits)
                ]
            all_split_files.append(split_files)

    for curm in all_nets:
        for view in range(nviews):
            view_dir = os.path.join(common_dir,
                                    '{}_view_{}'.format(curm, view))
            if args.split_type is None:
                cachedir = os.path.join(view_dir, 'full')
                conf = apt.create_conf(args.lbl_file, view, args.name,
                                       cache_dir=cachedir)
                if not args.only_check:
                    if not os.path.exists(conf.cachedir):
                        os.makedirs(conf.cachedir)
                    _create_net_db(curm, conf)
                check_db(curm, conf)

            elif args.split_type.startswith('prog'):
                split_type = args.split_type[5:]
                # `conf` was previously read here before any assignment on
                # this path (UnboundLocalError), or stale from another view.
                conf = apt.create_conf(args.lbl_file, view, args.name,
                                       cache_dir=common_dir)
                all_info = get_increasing_splits(conf, split_type)
                for cur_tr in prog_thresholds:
                    cachedir = os.path.join(view_dir, '{}'.format(cur_tr))
                    conf = apt.create_conf(args.lbl_file, view, args.name,
                                           cache_dir=cachedir)
                    # int(): on Python 2 round() returns a float, which is
                    # not a valid slice index.
                    split_ndx = int(round(len(all_info) / cur_tr))
                    cur_train = all_info[:split_ndx]
                    cur_val = all_info[split_ndx:]
                    if not os.path.exists(cachedir):
                        os.makedirs(cachedir)
                    split_file = os.path.join(cachedir, 'splitdata.json')
                    with open(split_file, 'w') as f:
                        json.dump([cur_train, cur_val], f)
                    if not args.only_check:
                        _create_net_db(curm, conf, split_file)
                    check_db(curm, conf)

            else:
                split_files = all_split_files[view]
                # Build a conf for THIS view; previously the conf left over
                # from the split-file loop (always the last view) was reused.
                conf = apt.create_conf(args.lbl_file, view, args.name,
                                       cache_dir=view_dir)
                for cur_split in range(nsplits):
                    if not os.path.exists(view_dir):
                        os.mkdir(view_dir)
                    conf.cachedir = os.path.join(view_dir,
                                                 'cv_{}'.format(cur_split))
                    if not os.path.exists(conf.cachedir):
                        os.mkdir(conf.cachedir)
                    conf.splitType = 'predefined'
                    split_file = split_files[cur_split]
                    if not args.only_check:
                        _create_net_db(curm, conf, split_file)
                    check_db(curm, conf)

    # Mirror the common tree as symlinks for each result set. An argument
    # list (not a shell string) keeps paths with spaces intact.
    for dest in ('theirs', 'ours', 'ours_default'):
        subprocess.call(
            ['cp', '-rs', common_dir,
             os.path.join(out_dir, args.name, dest)])