示例#1
0
文件: nbHG.py 项目: tsizemo2/APT
def cv_train_from_mat(lbl_file,
                      cdir,
                      cv_info_file,
                      models_run,
                      view=0,
                      skip_db=False,
                      create_splits=True,
                      dorun=False,
                      run_type='status'):
    """Create cross-validation splits from a CV-info file and optionally train.

    For every split index ``k`` in ``0 .. max(cv_info)``, the rows of
    ``in_info`` whose CV assignment equals ``k`` form the validation set.
    All other rows of ``label_info`` form the training set. For each split
    the function can write the split JSON file, build a database per model
    type, and launch training.

    Args:
        lbl_file: Path to the label file, opened read-only with h5py to read
            ``projname``; also forwarded to ``read_cvinfo`` and
            ``apt.create_conf``.
        cdir: Cache directory. Split files are written to
            ``<cdir>/<projname>/<cvi-basename>__split<k>.json``.
        cv_info_file: CV-info file. Its basename without extension prefixes
            every experiment name.
        models_run: Iterable of training-type names. ``'deeplabcut'`` uses
            ``apt.create_deepcut_db``, ``'leap'`` uses ``apt.create_leap_db``,
            and anything else uses ``apt.create_tfrecord``.
        view: View index forwarded to ``read_cvinfo``, ``apt.create_conf``
            and ``rapt.run_trainining``.
        skip_db: If True, neither split files nor databases are created.
        create_splits: If True (and ``skip_db`` is False), write the split
            JSON files. Otherwise the split files are expected to exist.
        dorun: If True, launch training for every model type of every split.
        run_type: Forwarded to ``rapt.run_trainining``.

    Raises:
        FileExistsError: If a split file is about to be written but already
            exists.
    """
    cv_info, in_info, label_info = read_cvinfo(lbl_file, cdir, cv_info_file,
                                               view)

    # Context manager: the HDF5 handle is closed even if the read raises.
    with h5py.File(lbl_file, 'r') as lbl:
        proj_name = apt.read_string(lbl['projname'])

    cvifileshort = os.path.splitext(os.path.basename(cv_info_file))[0]

    # Assumes split ids are 0-based integers, so max id + 1 is the fold count.
    n_splits = max(cv_info) + 1

    print("{} splits, {} rows in cvi, {} rows in lbl, projname {}".format(
        n_splits, len(cv_info), len(label_info), proj_name))

    # models_run is iterated once per split (and again when dorun is set).
    # Materialise it so a generator argument is not exhausted after the
    # first pass.
    models_run = list(models_run)
    # Loop-invariant: the set of all labeled rows is the same for every split.
    all_labels = set(label_info)

    for sndx in range(n_splits):
        # Rows assigned to this split become the validation set.
        val_info = [l for ndx, l in enumerate(in_info) if cv_info[ndx] == sndx]
        trn_info = list(all_labels - set(val_info))
        cur_split = [trn_info, val_info]
        exp_name = '{:s}__split{}'.format(cvifileshort, sndx)
        split_file = os.path.join(cdir, proj_name, exp_name) + '.json'
        if not skip_db and create_splits:
            # Never overwrite an existing split. Raise explicitly instead of
            # using `assert`, which is stripped under `python -O`.
            if os.path.exists(split_file):
                raise FileExistsError(
                    'Split file already exists: {}'.format(split_file))
            with open(split_file, 'w') as f:
                json.dump(cur_split, f)

        # Create the databases for every requested model type.
        if not skip_db:
            for train_type in models_run:
                conf = apt.create_conf(lbl_file, view, exp_name, cdir,
                                       train_type)
                conf.splitType = 'predefined'
                if train_type == 'deeplabcut':
                    create_db = apt.create_deepcut_db
                elif train_type == 'leap':
                    create_db = apt.create_leap_db
                else:
                    create_db = apt.create_tfrecord
                create_db(conf,
                          split=True,
                          split_file=split_file,
                          use_cache=True)

        if dorun:
            for train_type in models_run:
                # Bug fix: the original passed the undefined name
                # `elblbubxp_name` (NameError); the experiment is exp_name.
                # NOTE(review): `run_trainining` is kept exactly as spelled;
                # confirm it matches the function name in the rapt module.
                rapt.run_trainining(exp_name, train_type, view, run_type)
示例#2
0
import matplotlib.pyplot as plt
import apt_expts
import os
import ast
import apt_expts
import os
import pickle

# An empty CUDA_VISIBLE_DEVICES exposes no CUDA devices to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

gt_lbl = None
lbl_file = '/groups/branson/bransonlab/apt/experiments/data/roian_apt_dlstripped.lbl'
# Raw string: the backslashes before each parenthesis are kept literally
# (presumably to escape them for a shell command line -- confirm).
# In a plain literal, `\(` is an invalid escape sequence. It keeps the
# backslash only by accident and triggers DeprecationWarning, or
# SyntaxWarning on 3.12+. The value here is byte-for-byte the same.
op_af_graph = r'\(0,1\),\(0,2\),\(0,3\),\(1,2\),\(1,3\),\(2,3\)'

# Read the project name and the number of views from the HDF5 label file.
# The context manager closes the handle even if one of the reads raises.
# NOTE(review): h5py and apt are not among the imports listed at the top of
# this script -- confirm they are imported elsewhere.
with h5py.File(lbl_file, 'r') as lbl:
    proj_name = apt.read_string(lbl['projname'])
    nviews = int(apt.read_entry(lbl['cfg']['NumViews']))
# Run locations and cluster settings.
cache_dir = '/nrs/branson/mayank/apt_cache'
all_models = ['openpose']

gpu_model = 'GeForceRTX2080Ti'
# Directory path; the name suggests Singularity container material -- confirm.
sdir = '/groups/branson/home/kabram/bransonlab/APT/deepnet/singularity_stuff'
n_splits = 3

# Configuration overrides shared across runs (insertion order preserved).
common_conf = {
    'rrange': 10,
    'trange': 5,
    'mdn_use_unet_loss': True,
    'dl_steps': 100000,
    'decay_steps': 20000,
}