Code example #1
def classify_db_all(conf,
                    db_file,
                    model_files,
                    model_type,
                    name='deepnet',
                    distort=False,
                    return_hm=False,
                    hm_dec=1,
                    hm_floor=0.1,
                    hm_nclustermax=1):
    cur_out = []
    # TF-style checkpoints keep their metadata in a separate '.index' file;
    # leap and openpose models are stored as single files.
    extra_str = ''
    if model_type not in ['leap', 'openpose']:
        extra_str = '.index'
    # else:
    #     extra_str = '.h5'
    # modification time of each model file, stored alongside its predictions
    ts = [os.path.getmtime(f + extra_str) for f in model_files]

    for mndx, m in enumerate(model_files):
        # pred, label, gt_list = apt.classify_gt_data(conf, curm, out_file, m)
        tf_iterator = multiResData.tf_reader(conf, db_file, False)
        tf_iterator.batch_size = 1
        read_fn = tf_iterator.next
        pred_fn, close_fn, _ = apt.get_pred_fn(model_type,
                                               conf,
                                               m,
                                               name=name,
                                               distort=distort)
        ret_list = apt.classify_db(conf,
                                   read_fn,
                                   pred_fn,
                                   tf_iterator.N,
                                   return_hm=return_hm,
                                   hm_dec=hm_dec,
                                   hm_floor=hm_floor,
                                   hm_nclustermax=hm_nclustermax)
        pred, label, gt_list = ret_list[:3]
        if model_type == 'mdn':
            extra_stuff = ret_list[3:]  # model-specific extra outputs (mdn only)
        else:
            extra_stuff = 0
        close_fn()
        gt_list = np.array(gt_list)
        cur_out.append([pred, label, gt_list, m, extra_stuff, ts[mndx]])

    return cur_out
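
A minimal usage sketch. The label file, cache directory, database path, and checkpoint name below are hypothetical placeholders; only classify_db_all itself comes from the example above, and create_conf is called with the same keyword arguments used elsewhere on this page.

# Hypothetical invocation of classify_db_all; all paths are placeholders.
conf = apt.create_conf(lbl_file, view=0, name='deepnet',
                       cache_dir=cache_dir, net_type='mdn')
results = classify_db_all(conf,
                          db_file='/path/to/val_TF.tfrecords',
                          model_files=['/path/to/deepnet-20000'],
                          model_type='mdn')
for pred, label, gt_list, model_file, extra, mtime in results:
    print(model_file, pred.shape)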
Code example #2
File: trackStephenHead_KB.py    Project: tsizemo2/APT
def main(argv):

    parser = argparse.ArgumentParser()
    parser.add_argument("-s",
                        dest="sfilename",
                        help="text file with list of side view videos",
                        required=True)
    parser.add_argument(
        "-f",
        dest="ffilename",
        help=
        "text file with list of front view videos. The list of side view videos and front view videos should match up",
        required=True)
    parser.add_argument(
        "-d",
        dest="dltfilename",
        help=
        "text file with list of DLTs, one per fly as 'flynum,/path/to/dltfile'",
        required=True)
    parser.add_argument(
        "-body_lbl",
        dest="bodylabelfilename",
        help=
        "text file with list of body-label files, one per fly as 'flynum,/path/to/body_label.lbl'",
        default=bodylblfile)
    parser.add_argument("-net",
                        dest="net_name",
                        help="Name of the net to use for tracking",
                        default=default_net_name)
    parser.add_argument(
        "-o",
        dest="outdir",
        help="temporary output directory to store intermediate computations",
        required=True)
    parser.add_argument("-r",
                        dest="redo",
                        help="if specified will recompute everything",
                        action="store_true")
    parser.add_argument("-rt",
                        dest="redo_tracking",
                        help="if specified will only recompute tracking",
                        action="store_true")
    parser.add_argument("-gpu",
                        dest='gpunum',
                        type=int,
                        help="GPU to use [optional]")
    parser.add_argument("-makemovie",
                        dest='makemovie',
                        help="if specified will make results movie",
                        action="store_true")
    parser.add_argument(
        "-trackerpath",
        dest='trackerpath',
        help=
        "Absolute path to the compiled MATLAB tracker script run_compute3Dfrom2D.sh",
        default=defaulttrackerpath)
    parser.add_argument("-mcrpath",
                        dest='mcrpath',
                        help="Absolute path to MCR",
                        default=defaultmcrpath)
    parser.add_argument(
        "-ncores",
        dest="ncores",
        help="Number of cores to assign to each MATLAB tracker job",
        type=int,
        default=1)

    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "-only_detect",
        dest='detect',
        action="store_true",
        help="Do only the detection part of tracking which requires GPU")
    group.add_argument(
        "-only_track",
        dest='track',
        action="store_true",
        help="Do only the tracking part of the tracking which requires MATLAB")

    args = parser.parse_args(argv)
    # argparse's store_true flags default to False, so no None checks are needed
    if not args.detect and not args.track:
        args.detect = True
        args.track = True

    args.outdir = os.path.abspath(args.outdir)

    with open(args.sfilename, "r") as text_file:
        smovies = text_file.readlines()
    smovies = [x.rstrip() for x in smovies]
    with open(args.ffilename, "r") as text_file:
        fmovies = text_file.readlines()
    fmovies = [x.rstrip() for x in fmovies]

    print(smovies)
    print(fmovies)
    print(len(smovies))
    print(len(fmovies))

    if len(smovies) != len(fmovies):
        print("Side and front movies must match")
        sys.exit(1)

    if args.track:
        # read in dltfile
        dltdict = {}
        with open(args.dltfilename, 'r') as f:
            for l in f:
                lparts = l.split(',')
                if len(lparts) != 2:
                    print("Error splitting dlt file line %s into two parts" % l)
                    sys.exit(1)
                dltdict[float(lparts[0])] = lparts[1].strip()

        # compiled matlab command
        matscript = args.trackerpath + " " + args.mcrpath

    if args.detect:
        import numpy as np
        import tensorflow as tf
        from scipy import io
        from cvc import cvc
        import localSetup
        import PoseTools
        import multiResData
        import cv2
        import PoseUNet

        for ff in smovies + fmovies:
            if not os.path.isfile(ff):
                print("Movie %s not found" % (ff))
                sys.exit(1)
        if args.gpunum is not None:
            os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpunum)

        bodydict = {}
        with open(args.bodylabelfilename, 'r') as f:
            for l in f:
                lparts = l.split(',')
                if len(lparts) != 2:
                    print("Error splitting body label file line %s into two parts" % l)
                    sys.exit(1)
                bodydict[int(lparts[0])] = lparts[1].strip()

    for view in range(2):  # 0 for side and 1 for front
        if args.detect:
            tf.reset_default_graph()
        conf = apt.create_conf(lbl_file,
                               view=view,
                               name=name,
                               cache_dir=cache_dir,
                               net_type=model_type)
        update_conf(conf)
        if view == 0:
            # from stephenHeadConfig import sideconf as conf
            extrastr = '_side'
            valmovies = smovies
        else:
            # For FRONT
            # from stephenHeadConfig import conf as conf
            extrastr = '_front'
            valmovies = fmovies

        if args.detect:
            for try_num in range(4):
                try:
                    tf.reset_default_graph()
                    # pred_fn,close_fn,model_file = apt.get_unet_pred_fn(conf)
                    pred_fn, close_fn, model_file = apt.get_pred_fn(
                        model_type=model_type, conf=conf)
                    # self = PoseUNet.PoseUNet(conf, args.net_name)
                    # sess = self.init_net_meta(1)
                    break
                except ValueError:
                    print('Loading the net failed, retrying')
                    if try_num == 3:
                        raise ValueError(
                            'Could not load the network after 4 tries')

        for ndx in range(len(valmovies)):
            mname, _ = os.path.splitext(os.path.basename(valmovies[ndx]))
            oname = re.sub('!', '__', getexpname(valmovies[ndx]))
            pname = os.path.join(args.outdir, oname + extrastr)

            print(oname)

            # detect
            if args.detect and os.path.isfile(valmovies[ndx]) and \
               (args.redo or not os.path.isfile(pname + '.mat')):

                cap = cv2.VideoCapture(valmovies[ndx])
                height = int(cap.get(cvc.FRAME_HEIGHT))
                width = int(cap.get(cvc.FRAME_WIDTH))
                cap.release()
                try:
                    dirname = os.path.normpath(valmovies[ndx])
                    dir_parts = dirname.split(os.sep)
                    aa = re.search(r'fly_*(\d+)', dir_parts[-3])
                    flynum = int(aa.groups()[0])
                except AttributeError:
                    print('Could not find the fly number from movie name')
                    print("{} isn't in standard format".format(smovies[ndx]))
                    continue
                crop_loc_all = get_crop_locs(bodydict[flynum], view, height,
                                             width)  # return x first
                try:
                    predLocs, predScores, pred_ulocs, pred_conf = classify_movie(
                        valmovies[ndx], pred_fn, conf, crop_loc_all)
                    # predList = self.classify_movie(valmovies[ndx], sess, flipud=False)
                except KeyError:
                    continue


#                 orig_crop_loc = [crop_loc_all[i]-1 for i in (2,0)] # y first
#                 # rescale = conf.rescale
#                 # crop_loc = [int(x/rescale) for x in orig_crop_loc]
#                 # end_pad = [int((height-conf.imsz[0])/rescale)-crop_loc[0],int((width-conf.imsz[1])/rescale)-crop_loc[1]]
# #                crop_loc = [old_div(x,4) for x in orig_crop_loc]
# #                end_pad = [old_div(height,4)-crop_loc[0]-old_div(conf.imsz[0],4),old_div(width,4)-crop_loc[1]-old_div(conf.imsz[1],4)]
# #                 pp = [(0,0),(crop_loc[0],end_pad[0]),(crop_loc[1],end_pad[1]),(0,0)]
# #                 predScores = np.pad(predScores,pp,mode='constant',constant_values=-1.)
#
#                 predLocs[:,:,0] += orig_crop_loc[1] # x
#                 predLocs[:,:,1] += orig_crop_loc[0] # y

                hdf5storage.savemat(pname + '.mat', {
                    'locs': predLocs,
                    'scores': predScores,
                    'expname': valmovies[ndx],
                    'crop_loc': crop_loc_all,
                    'model_file': model_file,
                    'ulocs': pred_ulocs,
                    'pred_conf': pred_conf
                },
                                    appendmat=False,
                                    truncate_existing=True,
                                    gzip_compression_level=0)
                del predScores, predLocs

                print('Detecting:%s' % oname)

            # track
            if args.track and view == 1:

                oname_side = re.sub('!', '__', getexpname(smovies[ndx]))
                oname_front = re.sub('!', '__', getexpname(fmovies[ndx]))
                pname_side = os.path.join(args.outdir,
                                          oname_side + '_side.mat')
                pname_front = os.path.join(args.outdir,
                                           oname_front + '_front.mat')
                # 3d trajectories
                basename_front, _ = os.path.splitext(fmovies[ndx])
                basename_side, _ = os.path.splitext(smovies[ndx])
                savefile = basename_side + '_3Dres.mat'
                #savefile = os.path.join(args.outdir , oname_side + '_3Dres.mat')
                trkfile_front = basename_front + '.trk'
                trkfile_side = basename_side + '.trk'

                redo_tracking = args.redo or args.redo_tracking
                if os.path.isfile(savefile) and os.path.isfile(trkfile_front) and \
                   os.path.isfile(trkfile_side) and not redo_tracking:
                    print("%s, %s, and %s exist, skipping tracking" %
                          (savefile, trkfile_front, trkfile_side))
                    continue

                try:
                    dirname = os.path.normpath(smovies[ndx])
                    dir_parts = dirname.split(os.sep)
                    aa = re.search(r'fly_*(\d+)', dir_parts[-3])
                    flynum = int(aa.groups()[0])
                except AttributeError:
                    print('Could not find the fly number from movie name')
                    print("{} isn't in standard format".format(smovies[ndx]))
                    continue
                #print "Parsed fly number as %d"%flynum
                kinematfile = os.path.abspath(dltdict[flynum])

                jobid = oname_side

                scriptfile = os.path.join(args.outdir, jobid + '_track.sh')
                logfile = os.path.join(args.outdir, jobid + '_track.log')
                errfile = os.path.join(args.outdir, jobid + '_track.err')

                #print "matscript = " + matscript
                #print "pname_front = " + pname_front
                #print "pname_side = " + pname_side
                #print "kinematfile = " + kinematfile

                # make script to be qsubbed
                scriptf = open(scriptfile, 'w')
                scriptf.write('if [ -d %s ]\n' % args.outdir)
                scriptf.write('  then export MCR_CACHE_ROOT=%s/mcrcache%s\n' %
                              (args.outdir, jobid))
                scriptf.write('fi\n')
                scriptf.write('%s "%s" "%s" "%s" "%s" "%s" "%s"\n' %
                              (matscript, savefile, pname_front, pname_side,
                               kinematfile, trkfile_front, trkfile_side))
                scriptf.write('chmod g+w {}\n'.format(savefile))
                scriptf.write('chmod g+w {}\n'.format(trkfile_front))
                scriptf.write('chmod g+w {}\n'.format(trkfile_side))
                scriptf.close()
                os.chmod(
                    scriptfile, stat.S_IRUSR | stat.S_IRGRP | stat.S_IWUSR
                    | stat.S_IWGRP | stat.S_IXUSR | stat.S_IXGRP)

                #                cmd = "ssh login1 'source /etc/profile; qsub -pe batch %d -N %s -j y -b y -o '%s' -cwd '\"%s\"''"%(args.ncores,jobid,logfile,scriptfile)
                cmd = "ssh 10.36.11.34 'source /etc/profile; bsub -n %d -J %s -oo '%s' -eo '%s' -cwd . '\"%s\"''" % (
                    args.ncores, jobid, logfile, errfile, scriptfile)
                print(cmd)
                call(cmd, shell=True)
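
The script is meant to be driven from the command line; a sketch of a typical invocation, assuming the module is executed as a script (the list files and output directory below are placeholders):

# Hypothetical command line; the file lists and output directory are placeholders:
#   python trackStephenHead_KB.py -s side_movies.txt -f front_movies.txt \
#          -d dlt_list.txt -o /scratch/track_out -gpu 0 -only_detect
if __name__ == '__main__':
    import sys
    main(sys.argv[1:])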
Code example #3
files = files[bb:]  # bb: index of the first checkpoint to keep (set earlier)
n_max = 6
if len(files) > n_max:
    gg = len(files)
    sel = np.linspace(0, len(files) - 1, n_max).astype('int')
    files = [files[s] for s in sel]

out_file = os.path.join(conf.cachedir, train_name + '_results.p')
afiles = [f.replace('.index', '') for f in files]

for m in afiles[-1:]:
    tf_iterator = multiResData.tf_reader(conf, gt_file, False)
    tf_iterator.batch_size = 1
    read_fn = tf_iterator.next
    pred_fn, close_fn, _ = apt.get_pred_fn(train_type,
                                           conf,
                                           m,
                                           name=train_name)
    bsize = conf.batch_size
    all_f = np.zeros((bsize, ) + conf.imsz + (conf.img_dim, ))
    n = tf_iterator.N
    pred_locs = np.zeros([n, conf.n_classes, 2])
    unet_locs = np.zeros([n, conf.n_classes, 2])
    mdn_locs = np.zeros([n, conf.n_classes, 2])
    n_batches = int(math.ceil(float(n) / bsize))
    labeled_locs = np.zeros([n, conf.n_classes, 2])
    all_ims = np.zeros([n, conf.imsz[0], conf.imsz[1], conf.img_dim])

    info = []
    for cur_b in range(n_batches):
        cur_start = cur_b * bsize
        ppe = min(n - cur_start, bsize)
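
The snippet breaks off inside the batch loop; below is a plausible continuation of the same batching pattern, written as a sketch only. The tuple returned by read_fn and the 'locs' key returned by pred_fn are assumptions, not confirmed by the excerpt.

        # Sketch only: fill the fixed-size batch, then run the prediction function.
        for b in range(ppe):
            frame_in, cur_loc, cur_info = read_fn()  # assumed return tuple
            all_f[b, ...] = frame_in[0, ...]
            labeled_locs[cur_start + b, ...] = cur_loc[0, ...]
            info.append(cur_info)
        ret_dict = pred_fn(all_f)  # 'locs' key is an assumption
        pred_locs[cur_start:cur_start + ppe, ...] = ret_dict['locs'][:ppe, ...]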
Code example #4
def compute_peformance(args):
    H = h5py.File(args.lbl_file, 'r')
    nviews = int(apt.read_entry(H['cfg']['NumViews']))
    dir_name = args.whose

    if len(args.nets) == 0:
        all_nets = methods
    else:
        all_nets = args.nets

    all_preds = {}

    for view in range(nviews):
        db_file = os.path.join(out_dir, args.name,
                               args.gt_name) + '_view{}.tfrecords'.format(view)
        conf = apt.create_conf(args.lbl_file,
                               view,
                               name='a',
                               net_type=all_nets[0],
                               cache_dir=os.path.join(out_dir, args.name,
                                                      dir_name))
        conf.labelfile = args.gt_lbl
        if not (os.path.exists(db_file) and args.skip_gt_db):
            print('Creating GT DB file {}'.format(db_file))
            apt.create_tfrecord(conf,
                                split=False,
                                on_gt=True,
                                db_files=(db_file, ))

    for curm in all_nets:
        all_preds[curm] = []
        for view in range(nviews):
            cur_out = []
            db_file = os.path.join(
                out_dir, args.name,
                args.gt_name) + '_view{}.tfrecords'.format(view)
            if args.split_type is None:
                cachedir = os.path.join(out_dir, args.name, dir_name,
                                        '{}_view_{}'.format(curm,
                                                            view), 'full')
                conf = apt.create_conf(args.lbl_file,
                                       view,
                                       name='a',
                                       net_type=curm,
                                       cache_dir=cachedir)
                model_files, ts = get_model_files(conf, cachedir, curm)
                for mndx, m in enumerate(model_files):
                    out_file = m + '_' + args.gt_name
                    load = False
                    if curm in ('unet', 'deeplabcut'):
                        mm = m + '.index'
                    else:
                        mm = m
                    if os.path.exists(out_file + '.mat') and os.path.getmtime(
                            out_file + '.mat') > os.path.getmtime(mm):
                        load = True

                    if load:
                        H = sio.loadmat(out_file)
                        pred = H['pred_locs'] - 1
                        label = H['labeled_locs'] - 1
                        gt_list = H['list'] - 1
                    else:
                        # pred, label, gt_list = apt.classify_gt_data(conf, curm, out_file, m)
                        tf_iterator = multiResData.tf_reader(
                            conf, db_file, False)
                        tf_iterator.batch_size = 1
                        read_fn = tf_iterator.next
                        pred_fn, close_fn, _ = apt.get_pred_fn(curm, conf, m)
                        pred, label, gt_list = apt.classify_db(
                            conf, read_fn, pred_fn, tf_iterator.N)
                        close_fn()
                        mat_pred_locs = pred + 1
                        mat_labeled_locs = np.array(label) + 1
                        mat_list = gt_list

                        sio.savemat(
                            out_file, {
                                'pred_locs': mat_pred_locs,
                                'labeled_locs': mat_labeled_locs,
                                'list': mat_list
                            })

                    cur_out.append(
                        [pred, label, gt_list, m, out_file, view, 0, ts[mndx]])

            else:

                for cur_split in range(nsplits):
                    cachedir = os.path.join(out_dir, args.name,
                                            '{}_view_{}'.format(curm, view),
                                            'cv_{}'.format(cur_split))
                    conf = apt.create_conf(args.lbl_file,
                                           view,
                                           name='a',
                                           net_type=curm,
                                           cache_dir=cachedir)
                    model_files, ts = get_model_files(conf, cachedir, curm)
                    db_file = os.path.join(cachedir, 'val_TF.tfrecords')
                    for mndx, m in enumerate(model_files):
                        out_file = m + '.gt_data'
                        load = False
                        if curm in ('unet', 'deeplabcut'):
                            mm = m + '.index'
                        else:
                            mm = m
                        if os.path.exists(
                                out_file + '.mat') and os.path.getmtime(
                                    out_file + '.mat') > os.path.getmtime(mm):
                            load = True

                        if load:
                            H = sio.loadmat(out_file)
                            pred = H['pred_locs'] - 1
                            label = H['labeled_locs'] - 1
                            gt_list = H['list'] - 1
                        else:
                            tf_iterator = multiResData.tf_reader(
                                conf, db_file, False)
                            tf_iterator.batch_size = 1
                            read_fn = tf_iterator.next
                            pred_fn, close_fn, _ = apt.get_pred_fn(
                                curm, conf, m)
                            pred, label, gt_list = apt.classify_db(
                                conf, read_fn, pred_fn, tf_iterator.N)
                            close_fn()
                            mat_pred_locs = pred + 1
                            mat_labeled_locs = np.array(label) + 1
                            mat_list = gt_list

                            sio.savemat(
                                out_file, {
                                    'pred_locs': mat_pred_locs,
                                    'labeled_locs': mat_labeled_locs,
                                    'list': mat_list
                                })

                        cur_out.append([
                            pred, label, gt_list, m, out_file, view, cur_split,
                            ts[mndx]
                        ])

            all_preds[curm].append(cur_out)

    with open(
            os.path.join(out_dir, args.name, dir_name,
                         args.gt_name + '_results.p'), 'wb') as f:
        pickle.dump(all_preds, f)
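
A minimal sketch of reading the saved results back (the path is a placeholder; the file must be opened in binary mode to match the pickle written above):

# Hypothetical read-back of the results pickle; the path is a placeholder.
import pickle
with open('/path/to/gt_name_results.p', 'rb') as f:
    all_preds = pickle.load(f)
# all_preds[net][view] is a list of entries:
# [pred, label, gt_list, model_file, out_file, view, split, timestamp]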