Example No. 1
def get_increasing_splits(conf, split_type='random'):
    # creates json files for the xv splits
    local_dirs, _ = multiResData.find_local_dirs(conf)
    lbl = h5py.File(conf.labelfile, 'r')

    info = []
    n_labeled_frames = 0
    for ndx, dir_name in enumerate(local_dirs):
        if conf.has_trx_file:
            trx_files = multiResData.get_trx_files(lbl, local_dirs)
            trx = sio.loadmat(trx_files[ndx])['trx'][0]
            n_trx = len(trx)
        else:
            n_trx = 1

        for trx_ndx in range(n_trx):
            frames = multiResData.get_labeled_frames(lbl, ndx, trx_ndx)
            ts = get_label_ts(lbl, ndx, trx_ndx, frames)
            mm = [ndx] * frames.size
            tt = [trx_ndx] * frames.size
            cur_trx_info = list(zip(mm, tt, frames.tolist(), ts.tolist()))
            info.extend(cur_trx_info)
            n_labeled_frames += frames.size
    lbl.close()

    if split_type == 'time':
        info = sorted(info, key=lambda x: x[3])
    elif split_type == 'random':
        random.shuffle(info)  # shuffles in place; assigning the return value would set info to None
    else:
        raise ValueError('Unknown split_type: {}'.format(split_type))

    return info
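A minimal usage sketch for the function above; the helper name and the n_steps slicing are illustrative additions, assuming a conf object with the labelfile and has_trx_file attributes used throughout these examples.

def make_increasing_subsets(conf, n_steps=4):
    # Hypothetical helper: take the time-ordered (movie, trx, frame, timestamp)
    # tuples from get_increasing_splits and return progressively larger
    # prefixes, e.g. for training on increasing amounts of labeled data.
    info = get_increasing_splits(conf, split_type='time')
    step = max(1, len(info) // n_steps)
    return [info[:(k + 1) * step] for k in range(n_steps)]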
Example No. 2
def createvaldataJan(conf):
    from janLegConfig import conf  # note: this import shadows the conf argument
    ##
    L = h5py.File(conf.labelfile, 'r')
    localdirs,seldirs = multiResData.find_local_dirs(conf)

    ##
    pts = L['labeledpos']
    nmov = len(localdirs)
    fly_id = []
    num_labels = np.zeros([nmov,])
    for ndx in range(nmov):

        dstr = localdirs[ndx].split('/')
        fly_id.append(dstr[-3])
        curpts = np.array(L[pts[0,ndx]])
        frames = np.where(np.invert(np.all(np.isnan(curpts[:, :, :]), axis=(1, 2))))[0]
        num_labels[ndx] = len(frames)
    ##

    ufly = list(set(fly_id))
    fly_labels = np.zeros([len(ufly)])
    for ndx in range(len(ufly)):
        fly_ndx = [i for i, x in enumerate(fly_id) if x == ufly[ndx]]
        fly_labels[ndx] = np.sum(num_labels[fly_ndx])

    ##

    folds = 3
    lbls_fold = int(old_div(np.sum(num_labels),folds))

    imbalance = True
    while imbalance:
        per_fold = np.zeros([folds])
        fly_fold = np.zeros(len(ufly))
        for ndx in range(len(ufly)):
            done = False
            curfold = np.random.randint(folds)
            while not done:
                if per_fold[curfold]>lbls_fold:
                    curfold = (curfold+1)%folds
                else:
                    fly_fold[ndx] = curfold
                    per_fold[curfold] += fly_labels[ndx]
                    done = True
        imbalance = (per_fold.max()-per_fold.min())>(old_div(lbls_fold,3))
    print(per_fold)

##
    for ndx in range(folds):
        curvaldatafilename = os.path.join(conf.cachedir,conf.valdatafilename + '_fold_{}'.format(ndx))
        fly_val = np.where(fly_fold==ndx)[0]
        isval = []
        for indx in range(len(fly_val)):  # avoid reusing ndx, the fold index
            isval += [i for i, x in enumerate(fly_id) if x == ufly[fly_val[indx]]]

        with open(curvaldatafilename, 'wb') as f:  # binary mode for pickle
            pickle.dump([isval, localdirs, seldirs], f)
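The while-imbalance loop above (and its variants in the later examples) is a randomized greedy assignment of flies to folds, retried until the per-fold label counts are roughly balanced. A standalone sketch of that idea with hypothetical names, assuming fly_labels is a 1-D array of per-fly label counts:

import numpy as np

def assign_balanced_folds(fly_labels, folds=3, tol_frac=1.0 / 3):
    # Drop each fly into a random fold, skipping folds already past the
    # target size; retry the whole assignment until the folds are balanced.
    target = np.sum(fly_labels) / float(folds)
    while True:
        per_fold = np.zeros(folds)
        fly_fold = np.zeros(len(fly_labels), dtype=int)
        for ndx, n_lbl in enumerate(fly_labels):
            curfold = np.random.randint(folds)
            while per_fold[curfold] > target:
                curfold = (curfold + 1) % folds
            fly_fold[ndx] = curfold
            per_fold[curfold] += n_lbl
        if per_fold.max() - per_fold.min() <= target * tol_frac:
            return fly_fold, per_fold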
Example No. 3
def test_crop():
    import trackStephenHead_KB as ts
    import APT_interface as apt
    import multiResData
    import cv2
    from cvc import cvc
    import os
    import re
    import hdf5storage
    crop_reg_file = '/groups/branson/bransonlab/mayank/stephen_copy/crop_regression_params.mat'
    # lbl_file = '/groups/branson/bransonlab/mayank/stephen_copy/apt_cache/sh_trn4523_gtcomplete_cacheddata_bestPrms20180920_retrain20180920T123534_withGTres.lbl'
    lbl_file = '/groups/branson/bransonlab/mayank/stephen_copy/apt_cache/sh_trn4879_gtcomplete.lbl'
    crop_size = [[230, 350], [350, 350]]
    name = 'stephen_20181029'
    cache_dir = '/groups/branson/bransonlab/mayank/stephen_copy/apt_cache'
    bodylblfile = '/groups/branson/bransonlab/mayank/stephen_copy/fly2BodyAxis_lookupTable_Ben.csv'
    import h5py
    bodydict = {}
    f = open(bodylblfile, 'r')
    for l in f:
        lparts = l.split(',')
        if len(lparts) != 2:
            raise ValueError("Error splitting body label file line %s into two parts" % l)
        bodydict[int(lparts[0])] = lparts[1].strip()
    f.close()

    flynums = [[], []]
    crop_locs = [[], []]
    for view in range(2):
        conf = apt.create_conf(lbl_file,
                               view,
                               'aa',
                               cache_dir='/groups/branson/home/kabram/temp')
        movs = multiResData.find_local_dirs(conf)[0]

        for mov in movs:
            dirname = os.path.normpath(mov)
            dir_parts = dirname.split(os.sep)
            aa = re.search(r'fly_*(\d+)', dir_parts[-3])
            flynum = int(aa.groups()[0])
            if flynum in bodydict:
                cap = cv2.VideoCapture(mov)
                height = int(cap.get(cvc.FRAME_HEIGHT))
                width = int(cap.get(cvc.FRAME_WIDTH))
                cap.release()
                crop_locs[view].append(
                    ts.get_crop_locs(bodydict[flynum], view, height,
                                     width))  # return x first
                flynums[view].append(flynum)

    hdf5storage.savemat(
        '/groups/branson/bransonlab/mayank/stephen_copy/auto_crop_locs_trn4879',
        {
            'flynum': flynums,
            'crop_locs': crop_locs
        })
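A short sketch of reading the saved crop locations back; hdf5storage.loadmat mirrors the scipy.io interface, and the per-view pairing of fly numbers with crop boxes follows from how the two lists are appended together above. How the nested Python lists round-trip through the .mat format is an assumption.

import hdf5storage

out = hdf5storage.loadmat(
    '/groups/branson/bransonlab/mayank/stephen_copy/auto_crop_locs_trn4879')
for view in range(2):
    # flynum[view][k] and crop_locs[view][k] were appended together above,
    # so they should stay index-aligned per view.
    for flynum, crop in zip(out['flynum'][view], out['crop_locs'][view]):
        print(view, flynum, crop)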
Example No. 4
def duplicatesJan():
    ##
    from janLegConfig import conf
    import collections
    localdirs,seldirs = multiResData.find_local_dirs(conf)
    gg = [conf.getexpname(x) for x in localdirs]
    mm = [item for item, count in list(collections.Counter(gg).items()) if count > 1]
    L = h5py.File(conf.labelfile, 'r')
    pts = L['labeledpos']
    for ndx in range(len(mm)):
        aa = [idx for idx, x in enumerate(gg) if x == mm[ndx]]
        print(mm[ndx])
        for curndx in aa:
            curpts = np.array(L[pts[0, curndx]])
            frames = np.where(np.invert(np.all(np.isnan(curpts[:, :, :]), axis=(1, 2))))[0]
            print(curndx, frames.size)
Example No. 5
def createvaldataJay():
    ##
    from jayMouseConfig import sideconf as conf
    from jayMouseConfig import frontconf
    ##
    L = h5py.File(conf.labelfile, 'r')
    localdirs,seldirs = multiResData.find_local_dirs(conf)

    ##
    pts = L['labeledpos']
    nmov = len(localdirs)
    fly_id = []
    date_id = []
    num_labels = np.zeros([nmov,])
    for ndx in range(nmov):

        dstr = localdirs[ndx].split('/')
        mtch = re.search(r'M(\d+)_(\d{8})_', dstr[-2])
        if mtch is None:
            mtch = re.search(r'M(\d+)_(\d{8})_', dstr[-1])
        mm = mtch.groups()
        aa = '_'.join(mm)
        fly_id.append(aa)
        date_id.append(mm[1])  # the 8-digit date group
        curpts = np.array(L[pts[0,ndx]])
        frames = np.where(np.invert(np.all(np.isnan(curpts[:, :, :]), axis=(1, 2))))[0]
        num_labels[ndx] = len(frames)
    ##

    ufly = list(set(fly_id))
    fly_labels = np.zeros([len(ufly)])
    for ndx in range(len(ufly)):
        fly_ndx = [i for i, x in enumerate(fly_id) if x == ufly[ndx]]
        fly_labels[ndx] = np.sum(num_labels[fly_ndx])

    ##

    folds = 3
    lbls_fold = int(old_div(np.sum(num_labels),folds))

    imbalance = True
    while imbalance:
        per_fold = np.zeros([folds])
        fly_fold = np.zeros(len(ufly))
        for ndx in range(len(ufly)):
            done = False
            curfold = np.random.randint(folds)
            while not done:
                if per_fold[curfold]>lbls_fold:
                    curfold = (curfold+1)%folds
                else:
                    fly_fold[ndx] = curfold
                    per_fold[curfold] += fly_labels[ndx]
                    done = True
        imbalance = (per_fold.max()-per_fold.min())>(old_div(lbls_fold,2.5))
    print(per_fold)

##
    allisval = []
    localdirsf,seldirsf = multiResData.find_local_dirs(frontconf)

    for ndx in range(folds):
        curvaldatafilename = os.path.join(conf.cachedir,conf.valdatafilename + '_fold_{}'.format(ndx))
        curvaldatafilenamef = os.path.join(frontconf.cachedir,frontconf.valdatafilename + '_fold_{}'.format(ndx))
        fly_val = np.where(fly_fold==ndx)[0]
        isval = []
        for indx in range(len(fly_val)):
            isval += [i for i, x in enumerate(fly_id) if x == ufly[fly_val[indx]]]
        allisval.append(isval)
        with open(curvaldatafilename, 'wb') as f:  # binary mode for pickle
            pickle.dump([isval, localdirs, seldirs], f)
        with open(curvaldatafilenamef, 'wb') as f:
            pickle.dump([isval, localdirsf, seldirsf], f)
        io.savemat('/groups/branson/bransonlab/mayank/PoseTF/data/jayMouse/valSplits.mat', {'split': allisval})
Example No. 6
def createvaldata():
    ##
    conf = side1conf
    L = h5py.File(conf.labelfile, 'r')
    localdirs,seldirs = multiResData.find_local_dirs(conf)
    localdirs2,seldirs2 = multiResData.find_local_dirs(side2conf)
    localdirsb,seldirsb = multiResData.find_local_dirs(bottomconf)

    ##
    pts = L['labeledpos']
    nmov = len(localdirs)
    fly_id = []
    num_labels = np.zeros([nmov,])
    for ndx in range(nmov):

        dstr = localdirs[ndx].split('/')[-2]
        fly_id.append(dstr)
        curpts = np.array(L[pts[0,ndx]])
        frames = np.where(np.invert(np.all(np.isnan(curpts[:, :, :]), axis=(1, 2))))[0]
        num_labels[ndx] = len(frames)
    ##

    ufly = list(set(fly_id[1:])) # keep movie 1 aside
    fly_labels = np.zeros([len(ufly)])
    for ndx in range(len(ufly)):
        fly_ndx = [i for i, x in enumerate(fly_id) if x == ufly[ndx]]
        fly_labels[ndx] = np.sum(num_labels[fly_ndx])

    ##

    folds = 2
    lbls_fold = int(old_div(np.sum(num_labels[1:]),folds)) # keep movie 1 common

    imbalance = True
    while imbalance:
        per_fold = np.zeros([folds])
        fly_fold = np.zeros(len(ufly))
        for ndx in range(len(ufly)):
            done = False
            curfold = np.random.randint(folds)
            while not done:
                if per_fold[curfold]>lbls_fold:
                    curfold = (curfold+1)%folds
                else:
                    fly_fold[ndx] = curfold
                    per_fold[curfold] += fly_labels[ndx]
                    done = True
        imbalance = (per_fold.max()-per_fold.min())>(old_div(lbls_fold,3))
    print(per_fold)

##
    for ndx in range(folds):
        curvaldatafilename = os.path.join(conf.cachedir,conf.valdatafilename + '_fold_{}'.format(ndx))
        curvaldatafilename2 = os.path.join(side2conf.cachedir,side2conf.valdatafilename + '_fold_{}'.format(ndx))
        curvaldatafilenameb = os.path.join(bottomconf.cachedir,bottomconf.valdatafilename + '_fold_{}'.format(ndx))
        fly_val = np.where(fly_fold==ndx)[0]
        isval = []
        for ix in range(len(fly_val)):
            isval += [i for i, x in enumerate(fly_id) if x == ufly[fly_val[ix]]]
        with open(curvaldatafilename, 'wb') as f:  # binary mode for pickle
            pickle.dump([isval, localdirs, seldirs], f)
        with open(curvaldatafilename2, 'wb') as f:
            pickle.dump([isval, localdirs2, seldirs2], f)
        with open(curvaldatafilenameb, 'wb') as f:
            pickle.dump([isval, localdirsb, seldirsb], f)
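A hedged sketch of how one of the per-fold files written above might be read back to split the movie list into train and validation sets; the helper name is hypothetical and it assumes the pickles were written in binary mode as in the edited examples.

import os
import pickle

def load_val_split(conf, fold):
    # Hypothetical reader for one fold: the file holds [isval, localdirs, seldirs],
    # where isval indexes the validation movies within localdirs.
    fname = os.path.join(conf.cachedir,
                         conf.valdatafilename + '_fold_{}'.format(fold))
    with open(fname, 'rb') as f:
        isval, localdirs, seldirs = pickle.load(f)
    val_dirs = [localdirs[i] for i in isval]
    train_dirs = [d for i, d in enumerate(localdirs) if i not in isval]
    return train_dirs, val_dirs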
Example No. 7
    def create_db(self, split_file=None):
        assert self.rnn_pp_hist % self.conf.batch_size == 0, 'make sure the history is a multiple of batch size'
        # assert len(self.conf.mdn_groups)==1, 'This works only for single group. check for line 118'
        net = PoseUNet_resnet.PoseUMDN_resnet(self.conf, self.mdn_name)
        pred_fn, close_fn, _ = net.get_pred_fn()

        conf = self.conf
        on_gt = False
        db_files = ()
        if split_file is not None:
            self.conf.splitType = 'predefined'
            predefined = PoseTools.json_load(split_file)
            split = True
        else:
            predefined = None
            split = False

        mov_split = None

        local_dirs, _ = multiResData.find_local_dirs(conf, on_gt=False)
        lbl = h5py.File(conf.labelfile, 'r')
        view = conf.view
        flipud = conf.flipud
        npts_per_view = np.array(lbl['cfg']['NumLabelPoints'])[0, 0]
        sel_pts = int(view * npts_per_view) + conf.selpts

        out_fns = [True, False]
        data = [[], []]
        count = 0
        for ndx, dir_name in enumerate(local_dirs):

            cur_pts = multiResData.trx_pts(lbl, ndx, on_gt)
            crop_loc = PoseTools.get_crop_loc(lbl, ndx, view, on_gt)
            cap = movies.Movie(dir_name)

            if conf.has_trx_file:
                trx_files = multiResData.get_trx_files(lbl, local_dirs, on_gt)
                trx = sio.loadmat(trx_files[ndx])['trx'][0]
                n_trx = len(trx)
                trx_split = np.random.random(n_trx) < conf.valratio
                first_frames = np.array([
                    x['firstframe'][0, 0] for x in trx
                ]) - 1  # for converting from 1 indexing to 0 indexing
                end_frames = np.array([
                    x['endframe'][0, 0] for x in trx
                ]) - 1  # for converting from 1 indexing to 0 indexing
            else:
                trx = [None]
                n_trx = 1
                trx_split = None
                cur_pts = cur_pts[np.newaxis, ...]

            for trx_ndx in range(n_trx):

                frames = multiResData.get_labeled_frames(
                    lbl, ndx, trx_ndx, on_gt)
                cur_trx = trx[trx_ndx]
                for fnum in frames:
                    info = [ndx, fnum, trx_ndx]
                    cur_out = multiResData.get_cur_env(out_fns,
                                                       split,
                                                       conf,
                                                       info,
                                                       mov_split,
                                                       trx_split=trx_split,
                                                       predefined=predefined)
                    num_rep = 1 + cur_out * (self.train_rep - 1)

                    orig_ims = []
                    orig_labels = []
                    for fndx in range(-self.rnn_pp_hist, self.rnn_pp_hist):
                        frame_in, cur_loc = multiResData.get_patch(
                            cap,
                            fnum,
                            conf,
                            cur_pts[trx_ndx, fnum, :, sel_pts],
                            cur_trx=cur_trx,
                            flipud=flipud,
                            crop_loc=crop_loc,
                            offset=fndx)
                        orig_labels.append(cur_loc)
                        orig_ims.append(frame_in)

                    orig_ims = np.array(orig_ims)
                    orig_labels = np.array(orig_labels)

                    for rep in range(num_rep):
                        cur_pred = np.ones([
                            self.rnn_pp_hist * 2, self.conf.n_classes, 2
                        ]) * np.nan
                        raw_preds = np.ones([
                            self.rnn_pp_hist * 2, self.conf.n_classes, 2
                        ]) * np.nan
                        unet_preds = np.ones([
                            self.rnn_pp_hist * 2, self.conf.n_classes, 2
                        ]) * np.nan

                        cur_ims, cur_labels = PoseTools.preprocess_ims(
                            orig_ims,
                            orig_labels,
                            conf,
                            distort=cur_out,
                            scale=self.conf.rescale,
                            group_sz=2 * self.rnn_pp_hist)

                        bsize = self.conf.batch_size
                        nbatches = self.rnn_pp_hist // bsize * 2  # integer batch count; true division would break range()
                        for bndx in range(nbatches):
                            start = bndx * bsize
                            end = (bndx + 1) * bsize

                            ret_dict = pred_fn(cur_ims[start:end, ...])
                            pred_locs = ret_dict['locs']
                            raw_preds[start:end] = ret_dict['locs']
                            unet_preds[start:end] = ret_dict['locs_unet']

                            hsz = [
                                float(self.conf.imsz[1]) / 2,
                                float(self.conf.imsz[0]) / 2
                            ]
                            if conf.has_trx_file:
                                for e_ndx in range(bsize):
                                    trx_fnum = fnum - first_frames[trx_ndx]
                                    trx_fnum_ex = fnum - first_frames[
                                        trx_ndx] + e_ndx + start - self.rnn_pp_hist
                                    trx_fnum_ex = trx_fnum_ex if trx_fnum_ex > 0 else 0
                                    end_ndx = end_frames[
                                        trx_ndx] - first_frames[trx_ndx]
                                    trx_fnum_ex = trx_fnum_ex if trx_fnum_ex < end_ndx else end_ndx
                                    temp_pred = pred_locs[e_ndx, :, :]
                                    dx = cur_trx['x'][0,
                                                      trx_fnum] - cur_trx['x'][
                                                          0, trx_fnum_ex]
                                    dy = cur_trx['y'][0,
                                                      trx_fnum] - cur_trx['y'][
                                                          0, trx_fnum_ex]
                                    # -1 for 1-indexing in matlab and 0-indexing in python
                                    tt = cur_trx['theta'][
                                        0, trx_fnum] - cur_trx['theta'][
                                            0, trx_fnum_ex]
                                    R = [[np.cos(tt), -np.sin(tt)],
                                         [np.sin(tt), np.cos(tt)]]
                                    rr = (cur_trx['theta'][0, trx_fnum]
                                          ) + math.pi / 2
                                    Q = [[np.cos(rr), -np.sin(rr)],
                                         [np.sin(rr), np.cos(rr)]]
                                    cur_locs = np.dot(temp_pred - hsz,
                                                      R) + hsz - np.dot(
                                                          [dx, dy], Q)
                                    cur_pred[start + e_ndx, ...] = cur_locs
                            else:
                                cur_pred[start:end, :, :] = pred_locs

                        # ---------- Code for testing
                        # raw_pred = np.array(raw_preds).reshape((cur_pred.shape[0],) + raw_preds[0].shape[1:])
                        # f, ax = plt.subplots(2, 2)
                        # ax = ax.flatten()
                        # ex = 32
                        # xx = multiResData.get_patch(cap, fnum, conf, cur_pts[trx_ndx, fnum, :, sel_pts],
                        #                             cur_trx=cur_trx, crop_loc=None, offset=ex, stationary=False)
                        # ax[0].imshow(xx[0][:, :, 0], 'gray')
                        # ax[0].scatter(cur_pred[-self.rnn_pp_hist + ex, :, 0], cur_pred[-self.rnn_pp_hist + ex, :, 1])
                        # xx = multiResData.get_patch(cap, fnum, conf, cur_pts[trx_ndx, fnum, :, sel_pts],
                        #                             cur_trx=cur_trx, crop_loc=None, offset=0, stationary=False)
                        # ax[1].imshow(xx[0][:, :, 0], 'gray')
                        # ax[1].scatter(cur_pred[self.rnn_pp_hist, :, 0], cur_pred[self.rnn_pp_hist, :, 1])
                        # ax[2].imshow(cur_ims[-self.rnn_pp_hist + ex, :, :, 0], 'gray')
                        # ax[2].scatter(raw_pred[-self.rnn_pp_hist + ex, :, 0], raw_pred[-self.rnn_pp_hist + ex, :, 1])
                        # ax[3].imshow(cur_ims[self.rnn_pp_hist, :, :, 0], 'gray')
                        # ax[3].scatter(raw_pred[self.rnn_pp_hist, :, 0], raw_pred[self.rnn_pp_hist, :, 1])

                        # ---------- Code for testing II
                        # uu = np.array(unet_pred)
                        # uu = uu.reshape((-1,) + uu.shape[2:])
                        # import PoseTools
                        # unet_locs = PoseTools.get_pred_locs(uu)
                        # ss = np.sqrt(np.sum((unet_locs[rx, ...] - cur_labels[rx, ...]) ** 2, axis=-1))
                        # zz = np.diff(cur_pred, axis=0)
                        # rx = self.rnn_pp_hist
                        # dd = np.sqrt(np.sum((cur_pred[rx, ...] - cur_labels[rx, ...]) ** 2, axis=-1))
                        # print dd.max(), dd.argmax()
                        # ix = dd.argmax()
                        # n_ex = 3
                        # f, ax = plt.subplots(2, 4 + n_ex, figsize=(18, 12))
                        # ax = ax.T.flatten()
                        # ax[0].plot(cur_pred[:, ix, 0], cur_pred[:, ix, 1])
                        # ax[1].plot(zz[:, ix, 0], zz[:, ix, 1])
                        # ax[2].plot(zz[:, ix, 0])
                        # ax[3].plot(zz[:, ix, 1])
                        # ax[4].imshow(cur_ims[..., 0].mean(axis=0), 'gray')
                        # ax[5].imshow(cur_ims[..., 0].min(axis=0), 'gray')
                        # ax[6].imshow(cur_ims[self.rnn_pp_hist, ..., 0], 'gray')
                        # ax[6].scatter(cur_pred[self.rnn_pp_hist, :, 0], cur_pred[self.rnn_pp_hist, :, 1])
                        # ax[7].imshow(cur_ims[self.rnn_pp_hist, ..., 0], 'gray')
                        # ax[7].scatter(cur_labels[self.rnn_pp_hist, :, 0], cur_labels[self.rnn_pp_hist, :, 1])
                        # plt.title('{}'.format(ix))
                        # for xx in range(2 * n_ex):
                        #     xxi = multiResData.get_patch(cap, fnum, conf, cur_pts[trx_ndx, fnum, :, sel_pts],
                        #                                  cur_trx=cur_trx, crop_loc=None, offset=-xx, stationary=False)
                        #     ax[8 + xx].imshow(xxi[0][:, :, 0], 'gray')
                        #     ax[8 + xx].scatter(cur_pred[self.rnn_pp_hist - xx, ix, 0],
                        #                        cur_pred[self.rnn_pp_hist - xx, ix, 1])
                        #     if xx is 0:
                        #         ax[8 + xx].scatter(cur_pred[self.rnn_pp_hist - n_ex:, ix, 0],
                        #                            cur_pred[self.rnn_pp_hist - n_ex:, ix, 1])
                        #
                        rx = self.rnn_pp_hist
                        dd = np.sqrt(
                            np.sum(
                                (cur_pred[rx, ...] - cur_labels[rx, ...])**2,
                                axis=-1))
                        cur_info = [ndx, fnum, trx_ndx]
                        raw_preds = np.array(raw_preds)
                        if cur_out:
                            data[0].append([
                                cur_pred, cur_labels[rx, ...], cur_info,
                                raw_preds
                            ])
                        else:
                            data[1].append([
                                cur_pred, cur_labels[rx, ...], cur_info,
                                raw_preds
                            ])
                        count += 1

                    if count % 50 == 0:
                        sys.stdout.write('.')
                        with open(
                                os.path.join(conf.cachedir,
                                             self.data_name + '.p'), 'wb') as f:
                            pickle.dump(data, f)  # binary mode for pickle
                    if count % 2000 == 0:
                        sys.stdout.write('\n')

            cap.close()  # close the movie handles

        close_fn()
        with open(os.path.join(conf.cachedir, self.data_name + '.p'),
                  'wb') as f:
            pickle.dump(data, f)
        lbl.close()
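Inside create_db, the trx branch maps each prediction from the reference patch back to the current frame through the animal's change in heading and position. A minimal standalone sketch of that 2-D transform, with hypothetical argument names, assuming locs is an (n_pts, 2) array in patch pixel coordinates:

import numpy as np

def align_pred_to_current_frame(locs, dtheta, dxy, theta_cur, imsz):
    # Rotate predictions about the patch center by the change in heading
    # (dtheta), then subtract the translation (dxy) rotated into the
    # current frame's axes (theta_cur + pi/2), mirroring the expression
    # cur_locs = np.dot(temp_pred - hsz, R) + hsz - np.dot([dx, dy], Q) above.
    hsz = np.array([imsz[1] / 2.0, imsz[0] / 2.0])
    R = np.array([[np.cos(dtheta), -np.sin(dtheta)],
                  [np.sin(dtheta), np.cos(dtheta)]])
    rr = theta_cur + np.pi / 2
    Q = np.array([[np.cos(rr), -np.sin(rr)],
                  [np.sin(rr), np.cos(rr)]])
    return np.dot(locs - hsz, R) + hsz - np.dot(dxy, Q)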