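# Imports and module-level configuration the functions below rely on. The exact
# module providing `apt` and the home of int64_feature/float_feature/bytes_feature
# are assumptions inferred from the call sites, not confirmed by this snippet.
import json
import os

import h5py
import numpy as np
import tensorflow as tf

import APT_interface
import APT_interface as apt  # both spellings are used below
from APT_interface import bytes_feature, float_feature, int64_feature  # assumed location of the tf.train feature helpers

# get_crop_locs additionally expects module-level `crop_reg_file` (path to a .mat
# file with per-view crop regression parameters) and `crop_size` (per-view [w, h]);
# create_tfrecords relies on a get_conf(args) helper. None of these are defined
# in this snippet.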
def get_crop_locs(lblfile, view, height, width):
    """Return the crop window [x_left, x_right, y_top, y_bottom] for one view,
    placed from the labeled neck position via the per-view crop regression.
    All coordinates are 1-based (MATLAB indexing)."""
    bodylbl = apt.loadmat(lblfile)
    try:
        # labeledpos is stored sparsely as size/idx/val: rebuild the dense array
        # (MATLAB 1-based linear indices, hence idx - 1 and the reversed shape)
        lsz = np.array(bodylbl['labeledpos']['size'])
        curpts = np.nan * np.ones(lsz).flatten()
        idx = np.array(bodylbl['labeledpos']['idx']) - 1
        val = np.array(bodylbl['labeledpos']['val'])
        curpts[idx] = val
        curpts = np.reshape(curpts, np.flipud(lsz))
    except IndexError:
        # labeledpos is not in the sparse format: it is either a dense array,
        # or its first element holds a sparse struct or a dense array
        if bodylbl['labeledpos'].ndim == 3:
            curpts = np.array(bodylbl['labeledpos'])
            curpts = np.transpose(curpts, [2, 1, 0])
        else:
            if hasattr(bodylbl['labeledpos'][0], 'idx'):
                lsz = np.array(bodylbl['labeledpos'][0].size)
                curpts = np.nan * np.ones(lsz).flatten()
                idx = np.array(bodylbl['labeledpos'][0].idx) - 1
                val = np.array(bodylbl['labeledpos'][0].val)
                curpts[idx] = val
                curpts = np.reshape(curpts, np.flipud(lsz))
            else:
                curpts = np.array(bodylbl['labeledpos'][0])
                curpts = np.transpose(curpts, [2, 1, 0])
    # neck location: point index 5 (0-based) within this view's block of 10 landmarks
    neck_locs = curpts[0, :, 5 + 10 * view]
    # per-view linear regression mapping the neck position to the crop corner;
    # crop_reg_file and crop_size are expected at module level (see top of file)
    reg_params = apt.loadmat(crop_reg_file)
    x_reg = reg_params['reg_view{}_x'.format(view + 1)]
    y_reg = reg_params['reg_view{}_y'.format(view + 1)]
    # clamp the window so it stays inside the frame (1-based, inclusive bounds)
    x_left = int(round(x_reg[0] + x_reg[1] * neck_locs[0]))
    x_left = 1 if x_left < 1 else x_left
    x_right = x_left + crop_size[view][0] - 1
    if x_right > width:
        x_left = width - crop_size[view][0] + 1
        x_right = width
    y_top = int(round(y_reg[0] + y_reg[1] * neck_locs[1]))
    y_top = 1 if y_top < 1 else y_top
    y_bottom = y_top + crop_size[view][1] - 1
    if y_bottom > height:
        y_bottom = height
        y_top = height - crop_size[view][1] + 1
    return [x_left, x_right, y_top, y_bottom]
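# Minimal usage sketch (not part of the original code): apply the 1-based,
# inclusive window returned by get_crop_locs to a frame held as a numpy array.
# The helper name and the conversion to 0-based slices are illustrative only.
def crop_frame(frame, lblfile, view):
    height, width = frame.shape[:2]
    x_left, x_right, y_top, y_bottom = get_crop_locs(lblfile, view, height, width)
    # convert MATLAB-style 1-based inclusive bounds to 0-based Python slices
    return frame[y_top - 1:y_bottom, x_left - 1:x_right]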
def create_tfrecords(args):
    """Write train/val TFRecord files for both views and all three
    cross-validation splits of the head-tracking data (paths are hard-coded)."""
    data_file = '/groups/branson/bransonlab/mayank/PoseTF/headTracking/trnDataSH_20180503_notable.mat'
    split_file = '/groups/branson/bransonlab/apt/experiments/data/trnSplits_20180509.mat'

    D = h5py.File(data_file, 'r')
    S = APT_interface.loadmat(split_file)

    nims = D['IMain_crop2'].shape[1]
    # convert MATLAB 1-based movie ids and frame numbers to 0-based
    movid = np.array(D['mov_id']).T - 1
    frame_num = np.array(D['frm']).T - 1
    if args['split_type'] == 'easy':
        split_arr = S['xvMain3Easy']
    else:
        split_arr = S['xvMain3Hard']


    # one train/val TFRecord pair is written per (view, cross-validation split)
    for view in range(2):
        for split in range(3):
            args['view'] = view
            args['split_num'] = split
            conf = get_conf(args)
            if not os.path.exists(conf.cachedir):
                os.mkdir(conf.cachedir)

            outdir = conf.cachedir
            train_filename = os.path.join(outdir, conf.trainfilename)
            val_filename = os.path.join(outdir, conf.valfilename)
            env = tf.python_io.TFRecordWriter(train_filename + '.tfrecords')
            val_env = tf.python_io.TFRecordWriter(val_filename + '.tfrecords')
            splits = [[], []]  # [train rows, val rows], saved to splitdata.json
            all_locs = np.array(D['xyLblMain_crop2'])[view, :, :, :]
            for indx in range(nims):
                # images are stored as HDF5 object references, one per view
                cur_im = np.array(D[D['IMain_crop2'][view, indx]]).T
                cur_locs = all_locs[..., indx].T
                mov_num = movid[indx]
                cur_frame_num = frame_num[indx]

                # send the example to the val writer if it is held out in this
                # cross-validation split, otherwise to the train writer
                if split_arr[indx, split] == 1:
                    cur_env = val_env
                    splits[1].append([indx, cur_frame_num[0], 0])
                else:
                    cur_env = env
                    splits[0].append([indx, cur_frame_num[0], 0])


                im_raw = cur_im.tostring()  # raw image bytes; tostring is an alias of tobytes
                example = tf.train.Example(features=tf.train.Features(feature={
                    'height': int64_feature(cur_im.shape[0]),
                    'width': int64_feature(cur_im.shape[1]),
                    'depth': int64_feature(1),
                    'trx_ndx': int64_feature(0),
                    'locs': float_feature(cur_locs.flatten()),
                    'expndx': float_feature(mov_num),
                    'ts': float_feature(cur_frame_num[0]),
                    'image_raw': bytes_feature(im_raw)
                }))


                cur_env.write(example.SerializeToString())
            env.close()
            val_env.close()
            with open(os.path.join(outdir, 'splitdata.json'), 'w') as f:
                json.dump(splits, f)

    D.close()
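# Sketch of reading back one record written above, using the same TF1-era API
# as the writer (tf.python_io). The uint8 image dtype is an assumption; the
# snippet does not show how the IMain_crop2 images are typed.
def read_first_record(tfr_path):
    for rec in tf.python_io.tf_record_iterator(tfr_path):
        ex = tf.train.Example()
        ex.ParseFromString(rec)
        f = ex.features.feature
        h = f['height'].int64_list.value[0]
        w = f['width'].int64_list.value[0]
        im = np.frombuffer(f['image_raw'].bytes_list.value[0], dtype=np.uint8)
        locs = np.array(f['locs'].float_list.value)  # flattened landmark coords
        return im.reshape(h, w), locs
    return None, None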