# Example no. 1
# 0
# File: nbHG.py  Project: tsizemo2/APT
def cv_train_from_mat(lbl_file,
                      cdir,
                      cv_info_file,
                      models_run,
                      view=0,
                      skip_db=False,
                      create_splits=True,
                      dorun=False,
                      run_type='status'):
    """Write per-fold split files / training DBs from a CV-info file, optionally train.

    For every fold id found in the CV info, the rows assigned to that fold
    become the validation set and the remaining labeled rows the training
    set.  The pair is written as ``[train, val]`` JSON to
    ``<cdir>/<projname>/<cvifile-stem>__split<k>.json``; then, per model in
    ``models_run``, a DB is built from that split file and (if ``dorun``)
    training is launched via ``rapt.run_trainining``.

    Args:
        lbl_file: path to the APT label file (read with h5py for ``projname``).
        cdir: cache directory root.
        cv_info_file: CV-info file; its basename stem prefixes experiment names.
        models_run: iterable of net type names ('deeplabcut', 'leap', or other
            names handled by ``apt.create_tfrecord``).
        view: view index.
        skip_db: if True, neither split files nor DBs are created.
        create_splits: if True (and not ``skip_db``), write the split JSON files.
        dorun: if True, launch training for each fold and model.
        run_type: forwarded to ``rapt.run_trainining``.

    Raises:
        AssertionError: if a split file that should be created already exists.
    """
    # cv_info holds one fold id per row of in_info (see the comprehension below).
    cv_info, in_info, label_info = read_cvinfo(lbl_file, cdir, cv_info_file,
                                               view)

    # Context manager so the HDF5 handle is released even if the read fails.
    with h5py.File(lbl_file, 'r') as lbl:
        proj_name = apt.read_string(lbl['projname'])

    cvifileshort = os.path.splitext(os.path.basename(cv_info_file))[0]

    # Fold ids are assumed to be 0..max; every id up to the max gets a split.
    n_splits = max(cv_info) + 1

    print("{} splits, {} rows in cvi, {} rows in lbl, projname {}".format(
        n_splits, len(cv_info), len(label_info), proj_name))

    for sndx in range(n_splits):
        val_info = [info for ndx, info in enumerate(in_info)
                    if cv_info[ndx] == sndx]
        # Everything labeled that is not in this fold is training data.
        trn_info = list(set(label_info) - set(val_info))
        cur_split = [trn_info, val_info]
        exp_name = '{:s}__split{}'.format(cvifileshort, sndx)
        split_file = os.path.join(cdir, proj_name, exp_name) + '.json'
        if not skip_db and create_splits:
            # Refuse to silently overwrite an existing split definition.
            assert not os.path.exists(split_file), \
                'Split file already exists: {}'.format(split_file)
            with open(split_file, 'w') as f:
                json.dump(cur_split, f)

        # create the dbs
        if not skip_db:
            for train_type in models_run:
                conf = apt.create_conf(lbl_file, view, exp_name, cdir,
                                       train_type)
                conf.splitType = 'predefined'
                if train_type == 'deeplabcut':
                    apt.create_deepcut_db(conf,
                                          split=True,
                                          split_file=split_file,
                                          use_cache=True)
                elif train_type == 'leap':
                    apt.create_leap_db(conf,
                                       split=True,
                                       split_file=split_file,
                                       use_cache=True)
                else:
                    apt.create_tfrecord(conf,
                                        split=True,
                                        split_file=split_file,
                                        use_cache=True)
        if dorun:
            for train_type in models_run:
                # Fixed: previously referenced the undefined name
                # `elblbubxp_name` (NameError whenever dorun=True).
                rapt.run_trainining(exp_name, train_type, view, run_type)
# Example no. 2
# 0
def _create_net_db(curm, conf, *db_args):
    """Build the training DB for net type ``curm``; ``db_args`` follow ``conf``.

    ``db_args`` are forwarded unchanged to the apt DB creator, e.g.
    ``(False,)`` for an unsplit DB or ``(True, split_file)`` for a
    predefined split.  For 'deeplabcut' the DLC config is also written.

    Raises:
        ValueError: if ``curm`` is not unet/openpose/leap/deeplabcut.
    """
    if curm in ('unet', 'openpose'):
        apt.create_tfrecord(conf, *db_args)
    elif curm == 'leap':
        apt.create_leap_db(conf, *db_args)
    elif curm == 'deeplabcut':
        apt.create_deepcut_db(conf, *db_args)
        create_deepcut_cfg(conf)
    else:
        raise ValueError('Undefined net type: {}'.format(curm))


def create_db(args):
    """Create training DBs for every net in ``args.nets`` and every view.

    Output goes under ``<out_dir>/<args.name>/common``.  The layout depends
    on ``args.split_type``:

    * ``None``: one unsplit DB per net/view in ``<net>_view_<v>/full``.
    * starts with ``'prog'``: for each threshold in ``prog_thresholds`` a
      train/val split of ``get_increasing_splits`` output is written to
      ``<net>_view_<v>/<thr>/splitdata.json`` and a DB built from it.
    * anything else: ``nsplits`` cross-validation split files are created
      (if ``args.do_split``) or assumed present in ``splits_<v>``, and one
      DB per fold is built in ``<net>_view_<v>/cv_<k>``.

    If ``args.only_check`` is set, DB creation is skipped and only
    ``check_db`` runs.  After all nets, ``common`` is mirrored as symlink
    trees (``cp -rs``) into ``theirs``, ``ours`` and ``ours_default``.

    NOTE(review): ``out_dir``, ``nsplits`` and ``prog_thresholds`` are
    module-level names defined outside this view.

    Args:
        args: namespace with ``lbl_file``, ``nets``, ``split_type``,
            ``name``, ``do_split`` and ``only_check`` attributes.

    Raises:
        ValueError: for an unknown net type in ``args.nets``.
    """
    import subprocess

    # Context manager: the original handle was never closed.
    with h5py.File(args.lbl_file, 'r') as lbl:
        nviews = int(apt.read_entry(lbl['cfg']['NumViews']))
    all_nets = args.nets
    common_dir = os.path.join(out_dir, args.name, 'common')
    is_prog = (args.split_type is not None
               and args.split_type.startswith('prog'))
    is_cv = args.split_type is not None and not is_prog

    # One conf and one list of fold files per view. Keeping the conf per view
    # fixes a bug where every view reused the last view's conf.
    split_confs = []
    all_split_files = []
    if is_cv:
        for view in range(nviews):
            cachedir = os.path.join(common_dir, 'splits_{}'.format(view))
            os.makedirs(cachedir, exist_ok=True)
            conf = apt.create_conf(args.lbl_file,
                                   view,
                                   args.name,
                                   cache_dir=cachedir)
            conf.splitType = args.split_type
            print("Split type is {}".format(conf.splitType))
            if args.do_split:
                _, _, split_files = apt.create_cv_split_files(conf, nsplits)
            else:
                split_files = [
                    os.path.join(conf.cachedir,
                                 'cv_split_fold_{}.json'.format(ndx))
                    for ndx in range(nsplits)
                ]
            split_confs.append(conf)
            all_split_files.append(split_files)

    for curm in all_nets:
        for view in range(nviews):
            net_dir = os.path.join(common_dir,
                                   '{}_view_{}'.format(curm, view))

            if args.split_type is None:
                conf = apt.create_conf(args.lbl_file,
                                       view,
                                       args.name,
                                       cache_dir=os.path.join(net_dir, 'full'))
                if not args.only_check:
                    os.makedirs(conf.cachedir, exist_ok=True)
                    _create_net_db(curm, conf, False)
                check_db(curm, conf)

            elif is_prog:
                split_type = args.split_type[5:]
                # The original used `conf` here before it was ever assigned
                # in this mode, or a stale conf from a different view.
                # NOTE(review): a per-view conf rooted at common/ is assumed
                # sufficient for get_increasing_splits; confirm.
                view_conf = apt.create_conf(args.lbl_file,
                                            view,
                                            args.name,
                                            cache_dir=common_dir)
                all_info = get_increasing_splits(view_conf, split_type)

                for cur_tr in prog_thresholds:
                    cachedir = os.path.join(net_dir, '{}'.format(cur_tr))
                    conf = apt.create_conf(args.lbl_file,
                                           view,
                                           args.name,
                                           cache_dir=cachedir)
                    split_ndx = round(len(all_info) / cur_tr)
                    cur_train = all_info[:split_ndx]
                    cur_val = all_info[split_ndx:]
                    # The directory must exist before the split file is written.
                    os.makedirs(cachedir, exist_ok=True)
                    split_file = os.path.join(cachedir, 'splitdata.json')
                    with open(split_file, 'w') as f:
                        json.dump([cur_train, cur_val], f)
                    if not args.only_check:
                        _create_net_db(curm, conf, True, split_file)
                    check_db(curm, conf)

            else:
                conf = split_confs[view]
                split_files = all_split_files[view]
                for cur_split in range(nsplits):
                    conf.cachedir = os.path.join(net_dir,
                                                 'cv_{}'.format(cur_split))
                    os.makedirs(conf.cachedir, exist_ok=True)
                    conf.splitType = 'predefined'
                    if not args.only_check:
                        _create_net_db(curm, conf, True,
                                       split_files[cur_split])
                    check_db(curm, conf)

    # Mirror common/ as symlink trees once, after all nets are done. Inside
    # the per-net loop, a second `cp -rs` into an existing destination nested
    # a copy at <dest>/common. A list argv avoids shell interpretation of
    # paths; like os.system, the exit status is not checked.
    if all_nets:
        for dest_name in ('theirs', 'ours', 'ours_default'):
            dest_dir = os.path.join(out_dir, args.name, dest_name)
            subprocess.run(['cp', '-rs', common_dir, dest_dir])