Example #1
def cv_train_from_mat(lbl_file,
                      cdir,
                      cv_info_file,
                      models_run,
                      view=0,
                      skip_db=False,
                      create_splits=True,
                      dorun=False,
                      run_type='status'):

    cv_info, in_info, label_info = read_cvinfo(lbl_file, cdir, cv_info_file,
                                               view)

    lbl = h5py.File(lbl_file, 'r')
    proj_name = apt.read_string(lbl['projname'])
    lbl.close()

    cvifileshort = os.path.basename(cv_info_file)
    cvifileshort = os.path.splitext(cvifileshort)[0]

    n_splits = max(cv_info) + 1

    print("{} splits, {} rows in cvi, {} rows in lbl, projname {}".format(
        n_splits, len(cv_info), len(label_info), proj_name))

    for sndx in range(n_splits):
        val_info = [l for ndx, l in enumerate(in_info) if cv_info[ndx] == sndx]
        trn_info = list(set(label_info) - set(val_info))
        cur_split = [trn_info, val_info]
        exp_name = '{:s}__split{}'.format(cvifileshort, sndx)
        split_file = os.path.join(cdir, proj_name, exp_name) + '.json'
        if not skip_db and create_splits:
            assert not os.path.exists(split_file)
            with open(split_file, 'w') as f:
                json.dump(cur_split, f)

        # create the dbs
        if not skip_db:
            for train_type in models_run:
                conf = apt.create_conf(lbl_file, view, exp_name, cdir,
                                       train_type)
                conf.splitType = 'predefined'
                if train_type == 'deeplabcut':
                    apt.create_deepcut_db(conf,
                                          split=True,
                                          split_file=split_file,
                                          use_cache=True)
                elif train_type == 'leap':
                    apt.create_leap_db(conf,
                                       split=True,
                                       split_file=split_file,
                                       use_cache=True)
                else:
                    apt.create_tfrecord(conf,
                                        split=True,
                                        split_file=split_file,
                                        use_cache=True)
        if dorun:
            for train_type in models_run:
                rapt.run_training(exp_name, train_type, view, run_type)
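
# Usage sketch (not from the original source): one hypothetical way to drive
# cv_train_from_mat end to end. The paths and model list below are assumptions;
# substitute your own project, cache directory, and CV info .mat file.
cv_train_from_mat('/path/to/project.lbl',
                  '/path/to/apt_cache',
                  '/path/to/cv_info.mat',
                  models_run=['mdn', 'deeplabcut'],
                  view=0,
                  dorun=True,
                  run_type='submit')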
Example #2
def read_cvinfo(lbl_file, cdir, cv_info_file, view=0):

    conf = apt.create_conf(lbl_file, view, 'cv_dummy', cdir,
                           'mdn')  # net type irrelevant
    #lbl_movies, _ = multiResData.find_local_dirs(conf)
    #in_movies = [PoseTools.read_h5_str(data_info[k]) for k in data_info['movies'][0, :]]
    #assert lbl_movies == in_movies
    label_info = rapt.get_label_info(conf)

    cvi = h5py.File(cv_info_file, 'r')

    cv_info = apt.to_py(cvi['cvi'].value[:, 0].astype('int'))
    fr_info = apt.to_py(cvi['frame'].value[:, 0].astype('int'))
    m_info = apt.to_py(cvi['movieidx'].value[:, 0].astype('int'))
    if 'target' in cvi.keys():
        t_info = apt.to_py(cvi['target'].value[:, 0].astype('int'))
        in_info = [(a, b, c) for a, b, c in zip(m_info, fr_info, t_info)]
    else:
        in_info = [(a, b, 0) for a, b in zip(m_info, fr_info)]
    diff1 = list(set(label_info) - set(in_info))
    diff2 = list(set(in_info) - set(label_info))
    print('Number of labels that exist in label file but not in mat file: {}'.
          format(len(diff1)))
    print('Number of labels that exist in mat file but not in label file: {}'.
          format(len(diff2)))
    # assert all([a == b for a, b in zip(in_info, label_info)])

    return cv_info, in_info, label_info
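
# Usage sketch (not from the original source; paths are hypothetical):
# read_cvinfo returns one split index per row of the CV info file, plus the
# (movie, frame, target) tuples from the mat file and from the label file.
cv_info, in_info, label_info = read_cvinfo('/path/to/project.lbl',
                                           '/path/to/apt_cache',
                                           '/path/to/cv_info.mat', view=0)
n_splits = max(cv_info) + 1  # splits are numbered 0..n_splits-1, as in Example #1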
Example #3
def test_crop():
    import trackStephenHead_KB as ts
    import APT_interface as apt
    import multiResData
    import cv2
    from cvc import cvc
    import os
    import re
    import hdf5storage
    crop_reg_file = '/groups/branson/bransonlab/mayank/stephen_copy/crop_regression_params.mat'
    # lbl_file = '/groups/branson/bransonlab/mayank/stephen_copy/apt_cache/sh_trn4523_gtcomplete_cacheddata_bestPrms20180920_retrain20180920T123534_withGTres.lbl'
    lbl_file = '/groups/branson/bransonlab/mayank/stephen_copy/apt_cache/sh_trn4879_gtcomplete.lbl'
    crop_size = [[230, 350], [350, 350]]
    name = 'stephen_20181029'
    cache_dir = '/groups/branson/bransonlab/mayank/stephen_copy/apt_cache'
    bodylblfile = '/groups/branson/bransonlab/mayank/stephen_copy/fly2BodyAxis_lookupTable_Ben.csv'
    import h5py
    bodydict = {}
    f = open(bodylblfile, 'r')
    for l in f:
        lparts = l.split(',')
        if len(lparts) != 2:
            print("Error splitting body label file line %s into two parts" % l)
            raise SystemExit(1)
        bodydict[int(lparts[0])] = lparts[1].strip()
    f.close()

    flynums = [[], []]
    crop_locs = [[], []]
    for view in range(2):
        conf = apt.create_conf(lbl_file,
                               view,
                               'aa',
                               cache_dir='/groups/branson/home/kabram/temp')
        movs = multiResData.find_local_dirs(conf)[0]

        for mov in movs:
            dirname = os.path.normpath(mov)
            dir_parts = dirname.split(os.sep)
            aa = re.search(r'fly_*(\d+)', dir_parts[-3])
            flynum = int(aa.groups()[0])
            if flynum in bodydict:
                cap = cv2.VideoCapture(mov)
                height = int(cap.get(cvc.FRAME_HEIGHT))
                width = int(cap.get(cvc.FRAME_WIDTH))
                cap.release()
                crop_locs[view].append(
                    ts.get_crop_locs(bodydict[flynum], view, height,
                                     width))  # return x first
                flynums[view].append(flynum)

    hdf5storage.savemat(
        '/groups/branson/bransonlab/mayank/stephen_copy/auto_crop_locs_trn4879',
        {
            'flynum': flynums,
            'crop_locs': crop_locs
        })
Example #4
def save_cv_results(lbl_file,
                    cachedir,
                    view,
                    exp_name,
                    net,
                    model_file_short,
                    out_dir,
                    data_type,
                    kwout,
                    mdn_hm_floor=0.1,
                    db_file=None):

    conf_pvlist = None
    if net == 'openpose':
        if data_type == 'bub':
            conf_pvlist = ['op_affinity_graph', op_af_graph_bub_noslash]
        else:
            assert False, "define aff graph"

    return_hmaps = (net == 'mdn')

    conf = apt.create_conf(lbl_file,
                           view,
                           exp_name,
                           cachedir,
                           net,
                           conf_params=conf_pvlist)
    if db_file is None:
        db_file = os.path.join(conf.cachedir, 'val_TF.tfrecords')
    model_file = os.path.join(conf.cachedir, model_file_short)
    res = apt_expts.classify_db_all(conf,
                                    db_file, [model_file],
                                    net,
                                    return_hm=return_hmaps,
                                    hm_dec=1,
                                    hm_floor=mdn_hm_floor,
                                    hm_nclustermax=1)

    res.append(conf)
    out_file = "{}__vw{}__{}__{}.p".format(exp_name, view, net, kwout)
    out_file = os.path.join(out_dir, out_file)
    with open(out_file, 'wb') as f:
        pickle.dump(res, f)
    print("saved {}".format(out_file))
Example #5
def train():
    import PoseUNet_resnet as PoseURes
    import tensorflow as tf

    dstr = PoseTools.datestr()
    cur_name = 'stephen_{}'.format(dstr)

    for view in range(2):
        conf = apt.create_conf(lbl_file,
                               view=view,
                               name=cur_name,
                               cache_dir=cache_dir,
                               net_type=model_type)
        update_conf(conf)
        apt.create_tfrecord(conf, False, use_cache=True)
        tf.reset_default_graph()
        self = PoseURes.PoseUMDN_resnet(conf, name='deepnet')
        self.train_data_name = 'traindata'
        self.train_umdn()
Example #6
def predsingle(lbl_file, cachedir, view, exp_name, net, model_file_short,
               data_type):
    conf_pvlist = None
    if net == 'openpose':
        if data_type == 'bub':
            conf_pvlist = ['op_affinity_graph', op_af_graph_bub_noslash]
        else:
            assert False, "define aff graph"

    conf = apt.create_conf(lbl_file,
                           view,
                           exp_name,
                           cachedir,
                           net,
                           conf_params=conf_pvlist)
    db_file = os.path.join(conf.cachedir, 'val_TF.tfrecords')
    model_file = os.path.join(conf.cachedir, model_file_short)

    extra_str = ''
    if net not in ['leap', 'openpose']:
        extra_str = '.index'

    tf_iterator = multiResData.tf_reader(conf, db_file, False)
    tf_iterator.batch_size = 1
    read_fn = tf_iterator.next
    pred_fn, close_fn, _ = op.get_pred_fn(conf,
                                          model_file,
                                          name=None,
                                          rawpred=True)

    im, locs, info, _ = read_fn()
    print "im.shape is {}".format(im.shape)
    predmaps = pred_fn(im)
    close_fn()

    return predmaps, im, locs, info
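
# Usage sketch (not from the original source; all values are hypothetical):
# pull the raw prediction maps for a single validation image.
predmaps, im, locs, info = predsingle('/path/to/project.lbl',
                                      '/path/to/apt_cache', 0, 'cvi__split0',
                                      'openpose', 'deepnet-20000', 'bub')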
Example #7
def main(argv):

    parser = argparse.ArgumentParser()
    parser.add_argument("-s",
                        dest="sfilename",
                        help="text file with list of side view videos",
                        required=True)
    parser.add_argument(
        "-f",
        dest="ffilename",
        help=
        "text file with list of front view videos. The list of side view videos and front view videos should match up",
        required=True)
    parser.add_argument(
        "-d",
        dest="dltfilename",
        help=
        "text file with list of DLTs, one per fly as 'flynum,/path/to/dltfile'",
        required=True)
    parser.add_argument(
        "-body_lbl",
        dest="bodylabelfilename",
        help=
        "text file with list of body-label files, one per fly as 'flynum,/path/to/body_label.lbl'",
        default=bodylblfile)
    parser.add_argument("-net",
                        dest="net_name",
                        help="Name of the net to use for tracking",
                        default=default_net_name)
    parser.add_argument(
        "-o",
        dest="outdir",
        help="temporary output directory to store intermediate computations",
        required=True)
    parser.add_argument("-r",
                        dest="redo",
                        help="if specified will recompute everything",
                        action="store_true")
    parser.add_argument("-rt",
                        dest="redo_tracking",
                        help="if specified will only recompute tracking",
                        action="store_true")
    parser.add_argument("-gpu",
                        dest='gpunum',
                        type=int,
                        help="GPU to use [optional]")
    parser.add_argument("-makemovie",
                        dest='makemovie',
                        help="if specified will make results movie",
                        action="store_true")
    parser.add_argument(
        "-trackerpath",
        dest='trackerpath',
        help=
        "Absolute path to the compiled MATLAB tracker script run_compute3Dfrom2D.sh",
        default=defaulttrackerpath)
    parser.add_argument("-mcrpath",
                        dest='mcrpath',
                        help="Absolute path to MCR",
                        default=defaultmcrpath)
    parser.add_argument(
        "-ncores",
        dest="ncores",
        help="Number of cores to assign to each MATLAB tracker job",
        type=int,
        default=1)

    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "-only_detect",
        dest='detect',
        action="store_true",
        help="Do only the detection part of tracking which requires GPU")
    group.add_argument(
        "-only_track",
        dest='track',
        action="store_true",
        help="Do only the tracking part of the tracking which requires MATLAB")

    args = parser.parse_args(argv)
    if args.redo is None:
        args.redo = False
    if args.redo_tracking is None:
        args.redo_tracking = False

    if args.detect is False and args.track is False:
        args.detect = True
        args.track = True

    args.outdir = os.path.abspath(args.outdir)

    with open(args.sfilename, "r") as text_file:
        smovies = text_file.readlines()
    smovies = [x.rstrip() for x in smovies]
    with open(args.ffilename, "r") as text_file:
        fmovies = text_file.readlines()
    fmovies = [x.rstrip() for x in fmovies]

    print(smovies)
    print(fmovies)
    print(len(smovies))
    print(len(fmovies))

    if len(smovies) != len(fmovies):
        print("Side and front movies must match")
        raise SystemExit(1)

    if args.track:
        # read in dltfile
        dltdict = {}
        f = open(args.dltfilename, 'r')
        for l in f:
            lparts = l.split(',')
            if len(lparts) != 2:
                print("Error splitting dlt file line %s into two parts" % l)
                raise SystemExit(1)
            dltdict[float(lparts[0])] = lparts[1].strip()
        f.close()

        # compiled matlab command
        matscript = args.trackerpath + " " + args.mcrpath

    if args.detect:
        import numpy as np
        import tensorflow as tf
        from scipy import io
        from cvc import cvc
        import localSetup
        import PoseTools
        import multiResData
        import cv2
        import PoseUNet

        for ff in smovies + fmovies:
            if not os.path.isfile(ff):
                print("Movie %s not found" % (ff))
                raise SystemExit(1)
        if args.gpunum is not None:
            os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpunum)

        bodydict = {}
        f = open(args.bodylabelfilename, 'r')
        for l in f:
            lparts = l.split(',')
            if len(lparts) != 2:
                print(
                    "Error splitting body label file line %s into two parts" %
                    l)
                raise SystemExit(1)
            bodydict[int(lparts[0])] = lparts[1].strip()
        f.close()

    for view in range(2):  # 0 for front and 1 for side
        if args.detect:
            tf.reset_default_graph()
        conf = apt.create_conf(lbl_file,
                               view=view,
                               name=name,
                               cache_dir=cache_dir,
                               net_type=model_type)
        update_conf(conf)
        if view == 0:
            # from stephenHeadConfig import sideconf as conf
            extrastr = '_side'
            valmovies = smovies
        else:
            # For FRONT
            # from stephenHeadConfig import conf as conf
            extrastr = '_front'
            valmovies = fmovies

        if args.detect:
            for try_num in range(4):
                try:
                    tf.reset_default_graph()
                    # pred_fn,close_fn,model_file = apt.get_unet_pred_fn(conf)
                    pred_fn, close_fn, model_file = apt.get_pred_fn(
                        model_type=model_type, conf=conf)
                    # self = PoseUNet.PoseUNet(conf, args.net_name)
                    # sess = self.init_net_meta(1)
                    break
                except ValueError:
                    print('Loading the net failed, retrying')
                    if try_num == 3:
                        raise ValueError(
                            "Couldn't load the network after 4 tries")

        for ndx in range(len(valmovies)):
            mname, _ = os.path.splitext(os.path.basename(valmovies[ndx]))
            oname = re.sub('!', '__', getexpname(valmovies[ndx]))
            pname = os.path.join(args.outdir, oname + extrastr)

            print(oname)

            # detect
            if args.detect and os.path.isfile(valmovies[ndx]) and \
               (args.redo or not os.path.isfile(pname + '.mat')):

                cap = cv2.VideoCapture(valmovies[ndx])
                height = int(cap.get(cvc.FRAME_HEIGHT))
                width = int(cap.get(cvc.FRAME_WIDTH))
                cap.release()
                try:
                    dirname = os.path.normpath(valmovies[ndx])
                    dir_parts = dirname.split(os.sep)
                    aa = re.search(r'fly_*(\d+)', dir_parts[-3])
                    flynum = int(aa.groups()[0])
                except AttributeError:
                    print('Could not find the fly number from movie name')
                    print("{} isn't in standard format".format(smovies[ndx]))
                    continue
                crop_loc_all = get_crop_locs(bodydict[flynum], view, height,
                                             width)  # return x first
                try:
                    predLocs, predScores, pred_ulocs, pred_conf = classify_movie(
                        valmovies[ndx], pred_fn, conf, crop_loc_all)
                    # predList = self.classify_movie(valmovies[ndx], sess, flipud=False)
                except KeyError:
                    continue


#                 orig_crop_loc = [crop_loc_all[i]-1 for i in (2,0)] # y first
#                 # rescale = conf.rescale
#                 # crop_loc = [int(x/rescale) for x in orig_crop_loc]
#                 # end_pad = [int((height-conf.imsz[0])/rescale)-crop_loc[0],int((width-conf.imsz[1])/rescale)-crop_loc[1]]
# #                crop_loc = [old_div(x,4) for x in orig_crop_loc]
# #                end_pad = [old_div(height,4)-crop_loc[0]-old_div(conf.imsz[0],4),old_div(width,4)-crop_loc[1]-old_div(conf.imsz[1],4)]
# #                 pp = [(0,0),(crop_loc[0],end_pad[0]),(crop_loc[1],end_pad[1]),(0,0)]
# #                 predScores = np.pad(predScores,pp,mode='constant',constant_values=-1.)
#
#                 predLocs[:,:,0] += orig_crop_loc[1] # x
#                 predLocs[:,:,1] += orig_crop_loc[0] # y

                hdf5storage.savemat(pname + '.mat', {
                    'locs': predLocs,
                    'scores': predScores,
                    'expname': valmovies[ndx],
                    'crop_loc': crop_loc_all,
                    'model_file': model_file,
                    'ulocs': pred_ulocs,
                    'pred_conf': pred_conf
                },
                                    appendmat=False,
                                    truncate_existing=True,
                                    gzip_compression_level=0)
                del predScores, predLocs

                print('Detecting:%s' % oname)

            # track
            if args.track and view == 1:

                oname_side = re.sub('!', '__', getexpname(smovies[ndx]))
                oname_front = re.sub('!', '__', getexpname(fmovies[ndx]))
                pname_side = os.path.join(args.outdir,
                                          oname_side + '_side.mat')
                pname_front = os.path.join(args.outdir,
                                           oname_front + '_front.mat')
                # 3d trajectories
                basename_front, _ = os.path.splitext(fmovies[ndx])
                basename_side, _ = os.path.splitext(smovies[ndx])
                savefile = basename_side + '_3Dres.mat'
                #savefile = os.path.join(args.outdir , oname_side + '_3Dres.mat')
                trkfile_front = basename_front + '.trk'
                trkfile_side = basename_side + '.trk'

                redo_tracking = args.redo or args.redo_tracking
                if os.path.isfile(savefile) and os.path.isfile(trkfile_front) and \
                   os.path.isfile(trkfile_side) and not redo_tracking:
                    print("%s, %s, and %s exist, skipping tracking" %
                          (savefile, trkfile_front, trkfile_side))
                    continue

                try:
                    dirname = os.path.normpath(smovies[ndx])
                    dir_parts = dirname.split(os.sep)
                    aa = re.search(r'fly_*(\d+)', dir_parts[-3])
                    flynum = int(aa.groups()[0])
                except AttributeError:
                    print('Could not find the fly number from movie name')
                    print("{} isn't in standard format".format(smovies[ndx]))
                    continue
                #print "Parsed fly number as %d"%flynum
                kinematfile = os.path.abspath(dltdict[flynum])

                jobid = oname_side

                scriptfile = os.path.join(args.outdir, jobid + '_track.sh')
                logfile = os.path.join(args.outdir, jobid + '_track.log')
                errfile = os.path.join(args.outdir, jobid + '_track.err')

                #print "matscript = " + matscript
                #print "pname_front = " + pname_front
                #print "pname_side = " + pname_side
                #print "kinematfile = " + kinematfile

                # make script to be qsubbed
                scriptf = open(scriptfile, 'w')
                scriptf.write('if [ -d %s ]\n' % args.outdir)
                scriptf.write('  then export MCR_CACHE_ROOT=%s/mcrcache%s\n' %
                              (args.outdir, jobid))
                scriptf.write('fi\n')
                scriptf.write('%s "%s" "%s" "%s" "%s" "%s" "%s"\n' %
                              (matscript, savefile, pname_front, pname_side,
                               kinematfile, trkfile_front, trkfile_side))
                scriptf.write('chmod g+w {}\n'.format(savefile))
                scriptf.write('chmod g+w {}\n'.format(trkfile_front))
                scriptf.write('chmod g+w {}\n'.format(trkfile_side))
                scriptf.close()
                os.chmod(
                    scriptfile, stat.S_IRUSR | stat.S_IRGRP | stat.S_IWUSR
                    | stat.S_IWGRP | stat.S_IXUSR | stat.S_IXGRP)

                #                cmd = "ssh login1 'source /etc/profile; qsub -pe batch %d -N %s -j y -b y -o '%s' -cwd '\"%s\"''"%(args.ncores,jobid,logfile,scriptfile)
                cmd = "ssh 10.36.11.34 'source /etc/profile; bsub -n %d -J %s -oo '%s' -eo '%s' -cwd . '\"%s\"''" % (
                    args.ncores, jobid, logfile, errfile, scriptfile)
                print(cmd)
                call(cmd, shell=True)
Example #8
import APT_interface as apt
apt.main(args)  # args: APT_interface argument list, assembled elsewhere in the original script

##
model_file = '/home/mayank/Dropbox (HHMI)/temp/alice/leap/final_model.h5'
lbl_file = '/home/mayank/work/poseTF/data/leap/leap_data.lbl'
cache_dir = '/home/mayank/work/poseTF/cache/leap_db'

import sys
import socket
import numpy as np
import os

import APT_interface as apt
view = 0
conf = apt.create_conf(lbl_file, 0, 'leap_db', cache_dir, 'leap')
apt.create_leap_db(conf, False)

data_path = os.path.join(cache_dir, 'leap_train.h5')
cmd = 'python leap/training_MK.py {}'.format(data_path)
print('RUN: {}'.format(cmd))


##
import APT_interface as apt
import os
import h5py
import logging
from importlib import reload  # reload is not a builtin in Python 3
reload(apt)

lbl_file = '/home/mayank/work/poseTF/data/stephen/sh_cacheddata_20180717T095200.lbl'
Example #9
def run_training(lbl_file, cdir, exp_name, data_type, train_type, view,
                 run_type, **kwargs):

    common_cmd = 'APT_interface.py {} -name {} -cache {}'.format(
        lbl_file, exp_name, cdir)
    end_cmd = 'train -skip_db -use_cache'

    cmd_opts = {}
    cmd_opts['type'] = train_type
    cmd_opts['view'] = view + 1

    conf_opts = rapt.common_conf.copy()
    # conf_opts.update(other_conf[conf_id])
    conf_opts['save_step'] = conf_opts['dl_steps'] // 10
    for k in kwargs.keys():
        conf_opts[k] = kwargs[k]

    if train_type == 'openpose':
        if data_type == 'bub':
            conf_opts['op_affinity_graph'] = op_af_graph_bub
        else:
            assert False, "define aff graph"

    # if data_type in ['brit0' ,'brit1','brit2']:
    #     conf_opts['adjust_contrast'] = True
    #     if train_type == 'unet':
    #         conf_opts['batch_size'] = 2
    #     else:
    #         conf_opts['batch_size'] = 4

    # if data_type in ['romain']:
    #     if train_type in ['mdn','resnet_unet']:
    #         conf_opts['batch_size'] = 2
    #     elif train_type in ['unet']:
    #         conf_opts['batch_size'] = 1
    #     else:
    #         conf_opts['batch_size'] = 4
    #
    # if data_type in ['larva']:
    #     conf_opts['batch_size'] = 4
    #     conf_opts['adjust_contrast'] = True
    #     conf_opts['clahe_grid_size'] = 20
    #     if train_type in ['unet','resnet_unet','leap']:
    #         conf_opts['rescale'] = 2
    #         conf_opts['batch_size'] = 2
    #     if train_type in ['mdn']:
    #         conf_opts['batch_size'] = 4
    #         conf_opts['rescale'] = 2
    #         conf_opts['mdn_use_unet_loss'] = True
    #         # conf_opts['mdn_learning_rate'] = 0.0001
    #
    # if data_type == 'stephen':
    #     conf_opts['batch_size'] = 4

    # if data_type == 'carsen':
    #     if train_type in ['mdn','unet','resnet_unet']:
    #         conf_opts['rescale'] = 2.
    #     else:
    #         conf_opts['rescale'] = 1.
    #     conf_opts['adjust_contrast'] = True
    #     conf_opts['clahe_grid_size'] = 20
    #     if train_type in ['unet']:
    #         conf_opts['batch_size'] = 4
    #     else:
    #         conf_opts['batch_size'] = 8
    #
    # if op_af_graph is not None:
    #     conf_opts['op_affinity_graph'] = op_af_graph

    if len(conf_opts) > 0:
        conf_str = ' -conf_params'
        for k in conf_opts.keys():
            conf_str = '{} {} {} '.format(conf_str, k, conf_opts[k])
    else:
        conf_str = ''

    opt_str = ''
    for k in cmd_opts.keys():
        opt_str = '{} -{} {} '.format(opt_str, k, cmd_opts[k])

    cur_cmd = common_cmd + conf_str + opt_str + end_cmd
    cmd_name = '{}_view{}_{}_{}'.format(data_type, view, exp_name, train_type)
    if run_type == 'dry':
        print(cmd_name)
        print(cur_cmd)
        print()
    elif run_type == 'submit':
        print(cmd_name)
        print(cur_cmd)
        print()
        run_jobs(cmd_name, cur_cmd)
    elif run_type == 'status':
        conf = apt.create_conf(lbl_file, view, exp_name, cdir, train_type)
        check_train_status(cmd_name, conf.cachedir)
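
# Usage sketch (not from the original source; all values are hypothetical):
# run_type 'dry' only prints the assembled APT_interface.py command,
# 'submit' hands it to run_jobs, and 'status' checks the training directory.
run_training('/path/to/project.lbl', '/path/to/apt_cache', 'cv_split_0',
             'bub', 'mdn', 0, 'dry')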
Example #10
import APT_interface as apt
import os
# Alice's dataset

name = 'alice'
val_ratio = 0.1
lbl_file = '/home/kabram/Dropbox (HHMI)/temp/multitarget_bubble_expandedbehavior_20180425_FxdErrs_OptoParams20181126_dlstripped.lbl'
nviews = 1

for view in range(nviews):
    conf = apt.create_conf(lbl_file, view, 'tfds', '/home/kabram/temp', 'mdn')
    conf.cachedir = '/home/kabram/temp/tfds_{}_view{}'.format(name,view)
    conf.valratio = val_ratio
    os.makedirs(conf.cachedir,exist_ok=True)
    apt.create_tfrecord(conf, split=True, split_file=None, use_cache=True, on_gt=False)

Example #11
common_conf['batch_size'] = 8
common_conf['maxckpt'] = 20
cache_dir = '/nrs/branson/mayank/apt_cache'
train_name = 'deepnet'

assert gt_lbl is None
all_view = []
for view in range(nviews):
    out_exp = {}
    for tndx in range(len(all_models)):
        train_type = all_models[tndx]

        out_split = None
        for split in range(n_splits):
            exp_name = 'cv_split_{}'.format(split)
            mdn_conf = apt.create_conf(lbl_file, view, exp_name, cache_dir,
                                       'mdn')
            conf = apt.create_conf(lbl_file, view, exp_name, cache_dir,
                                   train_type)

            if op_af_graph is not None:
                conf.op_affinity_graph = ast.literal_eval(
                    op_af_graph.replace('\\', ''))
            files = glob.glob(
                os.path.join(conf.cachedir, "{}-[0-9]*").format(train_name))
            files.sort(key=os.path.getmtime)
            files = [
                f for f in files if os.path.splitext(f)[1] in ['.index', '']
            ]
            aa = [int(re.search(r'-(\d*)', f).groups(0)[0]) for f in files]
            aa = [b - a for a, b in zip(aa[:-1], aa[1:])]
            if any([a < 0 for a in aa]):
                bb = int(np.where(np.array(aa) < 0)[0]) + 1
                files = files[bb:]
Example #12
out_leap = [[preds,labels,[],[],[],[]]]
dd_leap = np.sqrt(np.sum((labels-preds)**2,1))
dd_leap = dd_leap.T

cache_dir = '/nrs/branson/mayank/apt_cache'
exp_name = 'apt_expt'
train_name = 'deepnet'
gt_file = os.path.join(cache_dir, rae.proj_name, 'gtdata', 'gtdata_view{}{}.tfrecords'.format(view, rae.gt_name))
H = multiResData.read_and_decode_without_session(gt_file, 32)
ex_im = np.array(H[0][0])[:, :, 0]
ex_loc = np.array(H[1][0])
our_res = pt.pickle_load('/nrs/branson/mayank/apt_cache/leap_dset/leap/view_0/apt_expt/deepnet_results.p')
our_preds = our_res[0][-1][0]
our_labels = our_res[0][-1][1]

conf = apt.create_conf(rae.lbl_file,0,'apt_expt',cache_dir,'leap')
orig_leap_models = ['/nrs/branson/mayank/apt_cache/leap_dset/leap/view_0/apt_expt/weights-045.h5',]
orig_leap = apt_expts.classify_db_all(conf,gt_file,orig_leap_models,'leap',name=train_name)

out_dict = {'leap':out_leap,'our leap':our_res[0],'leap_orig':orig_leap}
rae.plot_hist([out_dict,ex_im,ex_loc])

## mdn with and without unet
import run_apt_expts as rae
rae.setup('alice')
rae.run_mdn_no_unet()


##
import run_apt_expts as rae
rae.setup('alice')
Example #13
## Setup

import APT_interface as apt
import RNN_postprocess
import tensorflow as tf
import os
import easydict
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
tf.reset_default_graph()

exp_name = 'postprocess'
view = 0
mdn_name = 'deepnet'
lbl_file = '/groups/branson/bransonlab/apt/experiments/data/sh_trn4992_gtcomplete_cacheddata_updatedAndPpdbManuallyCopied20190402_dlstripped.lbl'

conf = apt.create_conf(lbl_file, view, exp_name,
                       '/nrs/branson/mayank/apt_cache', 'mdn')
conf.n_steps = 2

conf.rrange = 30
conf.trange = 10
conf.mdn_use_unet_loss = True
conf.dl_steps = 40000
conf.decay_steps = 20000
conf.save_step = 5000
conf.batch_size = 8
conf.normalize_img_mean = False
conf.maxckpt = 20

## Train MDN
args = easydict.EasyDict()
args.skip_db = False
Example #14
url_lib.urlretrieve(gt_file_url, gt_file)

res_file_url = 'https://www.dropbox.com/s/cr702321rvv3htl/alice_view0_time.mat?dl=1'
res_file = os.path.join(tdir,'alice_view0_time.mat')
url_lib.urlretrieve(res_file_url,res_file)

cmd = '-cache {} -name {} -conf_params batch_size {} dl_steps {} op_affinity_graph {} -type {{}} {} train -use_cache '.format(tdir, exp_name, bsz, dl_steps,op_af_graph, lbl_file)

##
import h5py
R = h5py.File(res_file,'r')

for net in net_types:
    apt.main(cmd.format(net).split())

    conf = apt.create_conf(lbl_file, 0, exp_name, tdir, net)
    # if data_type == 'stephen' and train_type == 'mdn':
    #     conf.mdn_use_unet_loss = False
    if op_af_graph is not None:
        conf.op_affinity_graph = ast.literal_eval(op_af_graph.replace('\\', ''))
    files = glob.glob(os.path.join(conf.cachedir, "{}-[0-9]*").format('deepnet'))
    files.sort(key=os.path.getmtime)
    files = [f for f in files if os.path.splitext(f)[1] in ['.index', '']]
    aa = [int(re.search(r'-(\d*)', f).groups(0)[0]) for f in files]
    aa = [b - a for a, b in zip(aa[:-1], aa[1:])]
    if any([a < 0 for a in aa]):
        bb = int(np.where(np.array(aa) < 0)[0]) + 1
        files = files[bb:]
    files = files[-1:]
    # n_max = 10
    # if len(files)> n_max:
Example #15
def compute_performance(args):
    H = h5py.File(args.lbl_file, 'r')
    nviews = int(apt.read_entry(H['cfg']['NumViews']))
    dir_name = args.whose

    if len(args.nets) == 0:
        all_nets = methods
    else:
        all_nets = args.nets

    all_preds = {}

    for view in range(nviews):
        db_file = os.path.join(out_dir, args.name,
                               args.gt_name) + '_view{}.tfrecords'.format(view)
        conf = apt.create_conf(args.lbl_file,
                               view,
                               name='a',
                               net_type=all_nets[0],
                               cache_dir=os.path.join(out_dir, args.name,
                                                      dir_name))
        conf.labelfile = args.gt_lbl
        if not (os.path.exists(db_file) and args.skip_gt_db):
            print('Creating GT DB file {}'.format(db_file))
            apt.create_tfrecord(conf,
                                split=False,
                                on_gt=True,
                                db_files=(db_file, ))

    for curm in all_nets:
        all_preds[curm] = []
        for view in range(nviews):
            cur_out = []
            db_file = os.path.join(
                out_dir, args.name,
                args.gt_name) + '_view{}.tfrecords'.format(view)
            if args.split_type is None:
                cachedir = os.path.join(out_dir, args.name, dir_name,
                                        '{}_view_{}'.format(curm,
                                                            view), 'full')
                conf = apt.create_conf(args.lbl_file,
                                       view,
                                       name='a',
                                       net_type=curm,
                                       cache_dir=cachedir)
                model_files, ts = get_model_files(conf, cachedir, curm)
                for mndx, m in enumerate(model_files):
                    out_file = m + '_' + args.gt_name
                    load = False
                    if curm == 'unet' or curm == 'deeplabcut':
                        mm = m + '.index'
                    else:
                        mm = m
                    if os.path.exists(out_file + '.mat') and os.path.getmtime(
                            out_file + '.mat') > os.path.getmtime(mm):
                        load = True

                    if load:
                        H = sio.loadmat(out_file)
                        pred = H['pred_locs'] - 1
                        label = H['labeled_locs'] - 1
                        gt_list = H['list'] - 1
                    else:
                        # pred, label, gt_list = apt.classify_gt_data(conf, curm, out_file, m)
                        tf_iterator = multiResData.tf_reader(
                            conf, db_file, False)
                        tf_iterator.batch_size = 1
                        read_fn = tf_iterator.next
                        pred_fn, close_fn, _ = apt.get_pred_fn(curm, conf, m)
                        pred, label, gt_list = apt.classify_db(
                            conf, read_fn, pred_fn, tf_iterator.N)
                        close_fn()
                        mat_pred_locs = pred + 1
                        mat_labeled_locs = np.array(label) + 1
                        mat_list = gt_list

                        sio.savemat(
                            out_file, {
                                'pred_locs': mat_pred_locs,
                                'labeled_locs': mat_labeled_locs,
                                'list': mat_list
                            })

                    cur_out.append(
                        [pred, label, gt_list, m, out_file, view, 0, ts[mndx]])

            else:

                for cur_split in range(nsplits):
                    cachedir = os.path.join(out_dir, args.name,
                                            '{}_view_{}'.format(curm, view),
                                            'cv_{}'.format(cur_split))
                    conf = apt.create_conf(args.lbl_file,
                                           view,
                                           name='a',
                                           net_type=curm,
                                           cache_dir=cachedir)
                    model_files, ts = get_model_files(conf, cachedir, curm)
                    db_file = os.path.join(cachedir, 'val_TF.tfrecords')
                    for mndx, m in enumerate(model_files):
                        out_file = m + '.gt_data'
                        load = False
                        if curm == 'unet' or curm == 'deeplabcut':
                            mm = m + '.index'
                        else:
                            mm = m
                        if os.path.exists(
                                out_file + '.mat') and os.path.getmtime(
                                    out_file + '.mat') > os.path.getmtime(mm):
                            load = True

                        if load:
                            H = sio.loadmat(out_file)
                            pred = H['pred_locs'] - 1
                            label = H['labeled_locs'] - 1
                            gt_list = H['list'] - 1
                        else:
                            tf_iterator = multiResData.tf_reader(
                                conf, db_file, False)
                            tf_iterator.batch_size = 1
                            read_fn = tf_iterator.next
                            pred_fn, close_fn, _ = apt.get_pred_fn(
                                curm, conf, m)
                            pred, label, gt_list = apt.classify_db(
                                conf, read_fn, pred_fn, tf_iterator.N)
                            close_fn()
                            mat_pred_locs = pred + 1
                            mat_labeled_locs = np.array(label) + 1
                            mat_list = gt_list

                            sio.savemat(
                                out_file, {
                                    'pred_locs': mat_pred_locs,
                                    'labeled_locs': mat_labeled_locs,
                                    'list': mat_list
                                })

                        cur_out.append([
                            pred, label, gt_list, m, out_file, view, cur_split,
                            ts[mndx]
                        ])

            all_preds[curm].append(cur_out)

    with open(
            os.path.join(out_dir, args.name, dir_name,
                         args.gt_name + '_results.p'), 'wb') as f:
        pickle.dump(all_preds, f)
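
# Usage sketch (not from the original source; all values are hypothetical):
# compute_performance expects an argparse-style namespace; easydict works
# for interactive use.
import easydict
args = easydict.EasyDict({
    'lbl_file': '/path/to/project.lbl', 'gt_lbl': '/path/to/gt.lbl',
    'whose': 'ours', 'nets': ['mdn'], 'name': 'alice', 'gt_name': 'gtdata',
    'skip_gt_db': False, 'split_type': None
})
compute_performance(args)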
Example #16
def create_db(args):
    H = h5py.File(args.lbl_file, 'r')
    nviews = int(apt.read_entry(H['cfg']['NumViews']))
    all_nets = args.nets

    all_split_files = []
    for view in range(nviews):
        if args.split_type is not None and not args.split_type.startswith(
                'prog'):
            cachedir = os.path.join(out_dir, args.name, 'common')
            if not os.path.exists(cachedir):
                os.mkdir(cachedir)
            cachedir = os.path.join(out_dir, args.name, 'common',
                                    'splits_{}'.format(view))
            if not os.path.exists(cachedir):
                os.mkdir(cachedir)
            conf = apt.create_conf(args.lbl_file,
                                   view,
                                   args.name,
                                   cache_dir=cachedir)
            conf.splitType = args.split_type
            print("Split type is {}".format(conf.splitType))
            if args.do_split:
                train_info, val_info, split_files = apt.create_cv_split_files(
                    conf, nsplits)
            else:
                split_files = [
                    os.path.join(conf.cachedir,
                                 'cv_split_fold_{}.json'.format(ndx))
                    for ndx in range(nsplits)
                ]
            all_split_files.append(split_files)

    for curm in all_nets:
        for view in range(nviews):

            if args.split_type is None:

                cachedir = os.path.join(out_dir, args.name, 'common',
                                        '{}_view_{}'.format(curm,
                                                            view), 'full')
                conf = apt.create_conf(args.lbl_file,
                                       view,
                                       args.name,
                                       cache_dir=cachedir)
                if not args.only_check:
                    if not os.path.exists(conf.cachedir):
                        os.makedirs(conf.cachedir)
                    if curm == 'unet' or curm == 'openpose':
                        apt.create_tfrecord(conf, False)
                    elif curm == 'leap':
                        apt.create_leap_db(conf, False)
                    elif curm == 'deeplabcut':
                        apt.create_deepcut_db(conf, False)
                        create_deepcut_cfg(conf)
                    else:
                        raise ValueError('Undefined net type: {}'.format(curm))

                check_db(curm, conf)
            elif args.split_type.startswith('prog'):
                split_type = args.split_type[5:]
                all_info = get_increasing_splits(conf, split_type)

                for cur_tr in prog_thresholds:
                    cachedir = os.path.join(out_dir, args.name, 'common',
                                            '{}_view_{}'.format(curm, view),
                                            '{}'.format(cur_tr))
                    conf = apt.create_conf(args.lbl_file,
                                           view,
                                           args.name,
                                           cache_dir=cachedir)
                    split_ndx = round(len(all_info) / cur_tr)
                    cur_train = all_info[:split_ndx]
                    cur_val = all_info[split_ndx:]
                    split_file = os.path.join(cachedir, 'splitdata.json')
                    with open(split_file, 'w') as f:
                        json.dump([cur_train, cur_val], f)
                    if not args.only_check:
                        if curm == 'unet' or curm == 'openpose':
                            apt.create_tfrecord(conf, True, split_file)
                        elif curm == 'leap':
                            apt.create_leap_db(conf, True, split_file)
                        elif curm == 'deeplabcut':
                            apt.create_deepcut_db(conf, True, split_file)
                            create_deepcut_cfg(conf)
                        else:
                            raise ValueError(
                                'Undefined net type: {}'.format(curm))
                    check_db(curm, conf)

            else:

                split_files = all_split_files[view]

                for cur_split in range(nsplits):
                    conf.cachedir = os.path.join(
                        out_dir, args.name, 'common',
                        '{}_view_{}'.format(curm, view))
                    if not os.path.exists(conf.cachedir):
                        os.mkdir(conf.cachedir)
                    conf.cachedir = os.path.join(
                        out_dir, args.name, 'common',
                        '{}_view_{}'.format(curm,
                                            view), 'cv_{}'.format(cur_split))
                    if not os.path.exists(conf.cachedir):
                        os.mkdir(conf.cachedir)
                    conf.splitType = 'predefined'
                    split_file = split_files[cur_split]
                    if not args.only_check:
                        if curm == 'unet' or curm == 'openpose':
                            apt.create_tfrecord(conf, True, split_file)
                        elif curm == 'leap':
                            apt.create_leap_db(conf, True, split_file)
                        elif curm == 'deeplabcut':
                            apt.create_deepcut_db(conf, True, split_file)
                            create_deepcut_cfg(conf)
                        else:
                            raise ValueError(
                                'Undefined net type: {}'.format(curm))
                    check_db(curm, conf)

        base_dir = os.path.join(out_dir, args.name, 'common')
        their_dir = os.path.join(out_dir, args.name, 'theirs')
        our_dir = os.path.join(out_dir, args.name, 'ours')
        our_default_dir = os.path.join(out_dir, args.name, 'ours_default')
        cmd = 'cp -rs {} {}'.format(base_dir, their_dir)
        os.system(cmd)
        cmd = 'cp -rs {} {}'.format(base_dir, our_dir)
        os.system(cmd)
        cmd = 'cp -rs {} {}'.format(base_dir, our_default_dir)
        os.system(cmd)
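
# Usage sketch (not from the original source; all values are hypothetical):
# build the full (unsplit) DBs for two net types and verify them with check_db.
import easydict
args = easydict.EasyDict({
    'lbl_file': '/path/to/project.lbl', 'name': 'alice',
    'nets': ['unet', 'leap'], 'split_type': None,
    'do_split': False, 'only_check': False
})
create_db(args)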