Example #1
def cv_train_from_mat(lbl_file,
                      cdir,
                      cv_info_file,
                      models_run,
                      view=0,
                      skip_db=False,
                      create_splits=True,
                      dorun=False,
                      run_type='status'):
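    """Run cross-validation training driven by a MATLAB cv-info file.

    For each split, write a [train, val] JSON split file, build the
    training DBs for every model in models_run, and optionally launch
    training. read_cvinfo and run_trainining are assumed to be defined
    in the surrounding module.
    """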

    cv_info, in_info, label_info = read_cvinfo(lbl_file, cdir, cv_info_file,
                                               view)

    lbl = h5py.File(lbl_file, 'r')
    proj_name = apt.read_string(lbl['projname'])
    lbl.close()

    cvifileshort = os.path.basename(cv_info_file)
    cvifileshort = os.path.splitext(cvifileshort)[0]

    n_splits = max(cv_info) + 1

    print("{} splits, {} rows in cvi, {} rows in lbl, projname {}".format(
        n_splits, len(cv_info), len(label_info), proj_name))

    for sndx in range(n_splits):
        val_info = [l for ndx, l in enumerate(in_info) if cv_info[ndx] == sndx]
        trn_info = list(set(label_info) - set(val_info))
        cur_split = [trn_info, val_info]
        exp_name = '{:s}__split{}'.format(cvifileshort, sndx)
        split_file = os.path.join(cdir, proj_name, exp_name) + '.json'
        if not skip_db and create_splits:
            assert not os.path.exists(split_file)
            with open(split_file, 'w') as f:
                json.dump(cur_split, f)

        # create the dbs
        if not skip_db:
            for train_type in models_run:
                conf = apt.create_conf(lbl_file, view, exp_name, cdir,
                                       train_type)
                conf.splitType = 'predefined'
                if train_type == 'deeplabcut':
                    apt.create_deepcut_db(conf,
                                          split=True,
                                          split_file=split_file,
                                          use_cache=True)
                elif train_type == 'leap':
                    apt.create_leap_db(conf,
                                       split=True,
                                       split_file=split_file,
                                       use_cache=True)
                else:
                    apt.create_tfrecord(conf,
                                        split=True,
                                        split_file=split_file,
                                        use_cache=True)
        if dorun:
            for train_type in models_run:
                run_trainining(exp_name, train_type, view, run_type)
Example #2
def train():
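    # Train a PoseUMDN_resnet model for both views; lbl_file, cache_dir,
    # model_type and update_conf are assumed to be module-level globals.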
    import PoseUNet_resnet as PoseURes
    import tensorflow as tf

    dstr = PoseTools.datestr()
    cur_name = 'stephen_{}'.format(dstr)

    for view in range(2):
        conf = apt.create_conf(lbl_file,
                               view=view,
                               name=cur_name,
                               cache_dir=cache_dir,
                               net_type=model_type)
        update_conf(conf)
        apt.create_tfrecord(conf, False, use_cache=True)
        tf.reset_default_graph()
        self = PoseURes.PoseUMDN_resnet(conf, name='deepnet')
        self.train_data_name = 'traindata'
        self.train_umdn()
Example #3
import shutil
import h5py
import logging
import APT_interface as apt
from importlib import reload  # reload was a builtin in Python 2
reload(apt)

lbl_file = '/home/mayank/work/poseTF/data/alice/multitarget_bubble_expandedbehavior_20180425_local.lbl'

split_file = '/home/mayank/work/poseTF/cache/apt_interface/multitarget_bubble_view0/test_leap/splitdata.json'

log = logging.getLogger()  # root logger
log.setLevel(logging.ERROR)

import deepcut.train
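# Build a TFRecord DB from the predefined split, then train OpenPose on
# it, skipping DB re-creation (args.skip_db = True).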
conf = apt.create_conf(lbl_file, 0, 'test_openpose_delete')
conf.splitType = 'predefined'
apt.create_tfrecord(conf, True, split_file=split_file)
from poseConfig import config as args
args.skip_db = True
apt.train_openpose(conf,args)

##
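# Evaluation cell: restore a DeepCut predictor and classify the saved
# validation data one example at a time.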
import deepcut.train
import tensorflow as tf
tf.reset_default_graph()  # must be called, not just referenced
conf.batch_size = 1
pred_fn, model_file = deepcut.train.get_pred_fn(conf)
rfn, n = deepcut.train.get_read_fn(conf, '/home/mayank/work/poseTF/cache/apt_interface/multitarget_bubble_view0/test_deepcut/val_data.p')
A = apt.classify_db(conf, rfn, pred_fn, n)

Example #4
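# OpenPose affinity-field graph over keypoints 0-4; the parentheses look
# shell-escaped so the string survives command-line parsing. lbl_file and
# gt_lbl are assumed to be defined in an earlier cell.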
op_af_graph = r'\(0,1\),\(0,2\),\(2,3\),\(1,3\),\(0,4\),\(1,4\)'

lbl = h5py.File(lbl_file, 'r')
proj_name = apt.read_string(lbl['projname'])
nviews = int(apt.read_entry(lbl['cfg']['NumViews']))
lbl.close()

cache_dir = '/nrs/branson/mayank/apt_cache'

train_type = 'mdn'
exp_name = 'apt_exp'
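# Write one GT TFRecord per view into the project's gtdata directory.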
for view in range(nviews):
    conf = apt.create_conf(gt_lbl, view, exp_name, cache_dir, train_type)
    gt_file = os.path.join(cache_dir, proj_name, 'gtdata',
                           'gtdata_view{}.tfrecords'.format(view))
    apt.create_tfrecord(conf, split=False, split_file=None, use_cache=False,
                        on_gt=True, db_files=[gt_file])

##
import APT_interface as apt
import os
import glob
import apt_expts
import h5py

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
cache_dir = '/nrs/branson/mayank/apt_cache'
exp_name = 'apt_expt'
train_name = 'deepnet'
view = 0

lbl_file = '/groups/branson/bransonlab/apt/experiments/data/multitarget_bubble_expandedbehavior_20180425_FxdErrs_OptoParams20181126_dlstripped.lbl'
Example #5
import APT_interface as apt
import os
# Alice's dataset

name = 'alice'
val_ratio = 0.1
lbl_file = '/home/kabram/Dropbox (HHMI)/temp/multitarget_bubble_expandedbehavior_20180425_FxdErrs_OptoParams20181126_dlstripped.lbl'
nviews = 1
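# Build train/val TFRecords with a random 10% validation split per view.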

for view in range(nviews):
    conf = apt.create_conf(lbl_file, view, 'tfds', '/home/kabram/temp', 'mdn')
    conf.cachedir = '/home/kabram/temp/tfds_{}_view{}'.format(name, view)
    conf.valratio = val_ratio
    os.makedirs(conf.cachedir, exist_ok=True)
    apt.create_tfrecord(conf, split=True, split_file=None, use_cache=True, on_gt=False)

Example #6
import os
import pickle
import h5py
import numpy as np
import scipy.io as sio
import APT_interface as apt
import multiResData


def compute_performance(args):
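    """Evaluate every saved model of each net on the GT database, caching
    predictions as .mat files. out_dir, methods, nsplits and
    get_model_files are assumed to be module-level names.
    """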
    H = h5py.File(args.lbl_file, 'r')
    nviews = int(apt.read_entry(H['cfg']['NumViews']))
    dir_name = args.whose

    if len(args.nets) == 0:
        all_nets = methods
    else:
        all_nets = args.nets

    all_preds = {}

    for view in range(nviews):
        db_file = os.path.join(out_dir, args.name,
                               args.gt_name) + '_view{}.tfrecords'.format(view)
        conf = apt.create_conf(args.lbl_file,
                               view,
                               name='a',
                               net_type=all_nets[0],
                               cache_dir=os.path.join(out_dir, args.name,
                                                      dir_name))
        conf.labelfile = args.gt_lbl
        if not (os.path.exists(db_file) and args.skip_gt_db):
            print('Creating GT DB file {}'.format(db_file))
            apt.create_tfrecord(conf,
                                split=False,
                                on_gt=True,
                                db_files=(db_file, ))

    for curm in all_nets:
        all_preds[curm] = []
        for view in range(nviews):
            cur_out = []
            db_file = os.path.join(
                out_dir, args.name,
                args.gt_name) + '_view{}.tfrecords'.format(view)
            if args.split_type is None:
                cachedir = os.path.join(out_dir, args.name, dir_name,
                                        '{}_view_{}'.format(curm,
                                                            view), 'full')
                conf = apt.create_conf(args.lbl_file,
                                       view,
                                       name='a',
                                       net_type=curm,
                                       cache_dir=cachedir)
                model_files, ts = get_model_files(conf, cachedir, curm)
                for mndx, m in enumerate(model_files):
                    out_file = m + '_' + args.gt_name
                    load = False
                    if curm == 'unet' or curm == 'deeplabcut':
                        mm = m + '.index'
                    else:
                        mm = m
                    if os.path.exists(out_file + '.mat') and os.path.getmtime(
                            out_file + '.mat') > os.path.getmtime(mm):
                        load = True

                    if load:
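                        # Reuse cached predictions; convert MATLAB's
                        # 1-based indices back to 0-based.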
                        H = sio.loadmat(out_file)
                        pred = H['pred_locs'] - 1
                        label = H['labeled_locs'] - 1
                        gt_list = H['list'] - 1
                    else:
                        # pred, label, gt_list = apt.classify_gt_data(conf, curm, out_file, m)
                        tf_iterator = multiResData.tf_reader(
                            conf, db_file, False)
                        tf_iterator.batch_size = 1
                        read_fn = tf_iterator.next
                        pred_fn, close_fn, _ = apt.get_pred_fn(curm, conf, m)
                        pred, label, gt_list = apt.classify_db(
                            conf, read_fn, pred_fn, tf_iterator.N)
                        close_fn()
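                        # Shift to MATLAB 1-based indexing before saving.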
                        mat_pred_locs = pred + 1
                        mat_labeled_locs = np.array(label) + 1
                        mat_list = gt_list

                        sio.savemat(
                            out_file, {
                                'pred_locs': mat_pred_locs,
                                'labeled_locs': mat_labeled_locs,
                                'list': mat_list
                            })

                    cur_out.append(
                        [pred, label, gt_list, m, out_file, view, 0, ts[mndx]])

            else:
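                # Cross-validation case: score each fold's models on its
                # held-out val_TF.tfrecords database.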

                for cur_split in range(nsplits):
                    cachedir = os.path.join(out_dir, args.name,
                                            '{}_view_{}'.format(curm, view),
                                            'cv_{}'.format(cur_split))
                    conf = apt.create_conf(args.lbl_file,
                                           view,
                                           name='a',
                                           net_type=curm,
                                           cache_dir=cachedir)
                    model_files, ts = get_model_files(conf, cachedir, curm)
                    db_file = os.path.join(cachedir, 'val_TF.tfrecords')
                    for mndx, m in enumerate(model_files):
                        out_file = m + '.gt_data'
                        load = False
                        if curm == 'unet' or curm == 'deeplabcut':
                            mm = m + '.index'
                        else:
                            mm = m
                        if os.path.exists(
                                out_file + '.mat') and os.path.getmtime(
                                    out_file + '.mat') > os.path.getmtime(mm):
                            load = True

                        if load:
                            H = sio.loadmat(out_file)
                            pred = H['pred_locs'] - 1
                            label = H['labeled_locs'] - 1
                            gt_list = H['list'] - 1
                        else:
                            tf_iterator = multiResData.tf_reader(
                                conf, db_file, False)
                            tf_iterator.batch_size = 1
                            read_fn = tf_iterator.next
                            pred_fn, close_fn, _ = apt.get_pred_fn(
                                curm, conf, m)
                            pred, label, gt_list = apt.classify_db(
                                conf, read_fn, pred_fn, tf_iterator.N)
                            close_fn()
                            mat_pred_locs = pred + 1
                            mat_labeled_locs = np.array(label) + 1
                            mat_list = gt_list

                            sio.savemat(
                                out_file, {
                                    'pred_locs': mat_pred_locs,
                                    'labeled_locs': mat_labeled_locs,
                                    'list': mat_list
                                })

                        cur_out.append([
                            pred, label, gt_list, m, out_file, view, cur_split,
                            ts[mndx]
                        ])

            all_preds[curm].append(cur_out)

    with open(
            os.path.join(out_dir, args.name, dir_name,
                         args.gt_name + '_results.p'), 'wb') as f:
        pickle.dump(all_preds, f)
Example #7
import os
import json
import h5py
import APT_interface as apt


def create_db(args):
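    """Create training databases for each net and view, optionally with
    cross-validation or progressive splits. out_dir, nsplits,
    prog_thresholds, check_db and create_deepcut_cfg are assumed to be
    module-level names.
    """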
    H = h5py.File(args.lbl_file, 'r')
    nviews = int(apt.read_entry(H['cfg']['NumViews']))
    all_nets = args.nets

    all_split_files = []
    for view in range(nviews):
        if args.split_type is not None and not args.split_type.startswith(
                'prog'):
            cachedir = os.path.join(out_dir, args.name, 'common')
            if not os.path.exists(cachedir):
                os.mkdir(cachedir)
            cachedir = os.path.join(out_dir, args.name, 'common',
                                    'splits_{}'.format(view))
            if not os.path.exists(cachedir):
                os.mkdir(cachedir)
            conf = apt.create_conf(args.lbl_file,
                                   view,
                                   args.name,
                                   cache_dir=cachedir)
            conf.splitType = args.split_type
            print("Split type is {}".format(conf.splitType))
            if args.do_split:
                train_info, val_info, split_files = apt.create_cv_split_files(
                    conf, nsplits)
            else:
                split_files = [
                    os.path.join(conf.cachedir,
                                 'cv_split_fold_{}.json'.format(ndx))
                    for ndx in range(nsplits)
                ]
            all_split_files.append(split_files)

    for curm in all_nets:
        for view in range(nviews):

            if args.split_type is None:

                cachedir = os.path.join(out_dir, args.name, 'common',
                                        '{}_view_{}'.format(curm,
                                                            view), 'full')
                conf = apt.create_conf(args.lbl_file,
                                       view,
                                       args.name,
                                       cache_dir=cachedir)
                if not args.only_check:
                    if not os.path.exists(conf.cachedir):
                        os.makedirs(conf.cachedir)
                    if curm == 'unet' or curm == 'openpose':
                        apt.create_tfrecord(conf, False)
                    elif curm == 'leap':
                        apt.create_leap_db(conf, False)
                    elif curm == 'deeplabcut':
                        apt.create_deepcut_db(conf, False)
                        create_deepcut_cfg(conf)
                    else:
                        raise ValueError('Undefined net type: {}'.format(curm))

                check_db(curm, conf)
            elif args.split_type.startswith('prog'):
                split_type = args.split_type[5:]  # strip the 'prog_' prefix
                # conf was used before assignment in this branch in the
                # original; build one against the common cache dir first.
                conf = apt.create_conf(args.lbl_file,
                                       view,
                                       args.name,
                                       cache_dir=os.path.join(
                                           out_dir, args.name, 'common'))
                all_info = get_increasing_splits(conf, split_type)

                for cur_tr in prog_thresholds:
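                    # Keep the first len(all_info)/cur_tr examples for
                    # training; the rest become validation.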
                    cachedir = os.path.join(out_dir, args.name, 'common',
                                            '{}_view_{}'.format(curm, view),
                                            '{}'.format(cur_tr))
                    conf = apt.create_conf(args.lbl_file,
                                           view,
                                           args.name,
                                           cache_dir=cachedir)
                    split_ndx = round(len(all_info) / cur_tr)
                    cur_train = all_info[:split_ndx]
                    cur_val = all_info[split_ndx:]
                    split_file = os.path.join(cachedir, 'splitdata.json')
                    with open(split_file, 'w') as f:
                        json.dump([cur_train, cur_val], f)
                    if not args.only_check:
                        if curm == 'unet' or curm == 'openpose':
                            apt.create_tfrecord(conf, True, split_file)
                        elif curm == 'leap':
                            apt.create_leap_db(conf, True, split_file)
                        elif curm == 'deeplabcut':
                            apt.create_deepcut_db(conf, True, split_file)
                            create_deepcut_cfg(conf)
                        else:
                            raise ValueError(
                                'Undefined net type: {}'.format(curm))
                    check_db(curm, conf)

            else:

                split_files = all_split_files[view]

                for cur_split in range(nsplits):
                    conf.cachedir = os.path.join(
                        out_dir, args.name, 'common',
                        '{}_view_{}'.format(curm, view))
                    if not os.path.exists(conf.cachedir):
                        os.mkdir(conf.cachedir)
                    conf.cachedir = os.path.join(
                        out_dir, args.name, 'common',
                        '{}_view_{}'.format(curm,
                                            view), 'cv_{}'.format(cur_split))
                    if not os.path.exists(conf.cachedir):
                        os.mkdir(conf.cachedir)
                    conf.splitType = 'predefined'
                    split_file = split_files[cur_split]
                    if not args.only_check:
                        if curm == 'unet' or curm == 'openpose':
                            apt.create_tfrecord(conf, True, split_file)
                        elif curm == 'leap':
                            apt.create_leap_db(conf, True, split_file)
                        elif curm == 'deeplabcut':
                            apt.create_deepcut_db(conf, True, split_file)
                            create_deepcut_cfg(conf)
                        else:
                            raise ValueError(
                                'Undefined net type: {}'.format(curm))
                    check_db(curm, conf)

        base_dir = os.path.join(out_dir, args.name, 'common')
        their_dir = os.path.join(out_dir, args.name, 'theirs')
        our_dir = os.path.join(out_dir, args.name, 'ours')
        our_default_dir = os.path.join(out_dir, args.name, 'ours_default')
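        # Mirror the common DB tree into the theirs/ours/ours_default
        # variants as symlink copies (cp -rs).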
        cmd = 'cp -rs {} {}'.format(base_dir, their_dir)
        os.system(cmd)
        cmd = 'cp -rs {} {}'.format(base_dir, our_dir)
        os.system(cmd)
        cmd = 'cp -rs {} {}'.format(base_dir, our_default_dir)
        os.system(cmd)