def _sparse_labels_to_dense(size, idx, val):
    """Expand sparse-format label positions into a dense float array.

    ``idx`` holds 1-based linear indices (presumably MATLAB column-major)
    into an array of shape ``size``, and ``val`` holds the values stored at
    those positions. Every other entry is NaN. The result is reshaped with
    the dimension order of ``size`` reversed. For column-major indices this
    equals the MATLAB array transposed, matching the
    ``np.transpose(..., [2, 1, 0])`` applied to dense label arrays.
    """
    # MATLAB files typically store sizes/indices as doubles; numpy needs
    # integers both for array shapes and for fancy indexing.
    lsz = np.asarray(size).astype(int).ravel()
    dense = np.full(int(np.prod(lsz)), np.nan)
    flat_idx = np.asarray(idx).astype(int).ravel() - 1  # 1-based -> 0-based
    dense[flat_idx] = np.asarray(val, dtype=float).ravel()
    return np.reshape(dense, lsz[::-1])


def _load_label_positions(lblfile):
    """Load ``labeledpos`` from a label file as a dense, axis-reversed array.

    Handles the four layouts the loader may return:
    - a sparse struct with ``size``/``idx``/``val`` keys;
    - a dense 3-D array;
    - a cell whose first element is a sparse struct (``.size``/``.idx``/``.val``);
    - a cell whose first element is a dense array.
    """
    lpos = apt.loadmat(lblfile)['labeledpos']
    # Only the key lookup is guarded: indexing an ndarray with a string raises
    # IndexError. Errors raised while decoding must not be mistaken for
    # "this is not a sparse struct".
    try:
        sparse_fields = (lpos['size'], lpos['idx'], lpos['val'])
    except IndexError:
        pass
    else:
        return _sparse_labels_to_dense(*sparse_fields)

    if lpos.ndim == 3:
        return np.transpose(np.array(lpos), [2, 1, 0])
    first = lpos[0]
    if hasattr(first, 'idx'):
        return _sparse_labels_to_dense(first.size, first.idx, first.val)
    return np.transpose(np.array(first), [2, 1, 0])


def _fit_crop_window(start, length, limit):
    """Fit a 1-based inclusive window of ``length`` pixels into ``[1, limit]``.

    The window is first moved so it starts at >= 1. If it then ends past
    ``limit``, it is shifted left so it ends exactly at ``limit``.

    Returns:
        ``(start, end)``, both 1-based and inclusive.
    """
    start = max(start, 1)
    end = start + length - 1
    if end > limit:
        # NOTE(review): if length > limit, start becomes < 1 here; this case
        # is not handled (same as the original code).
        start = limit - length + 1
        end = limit
    return start, end


def get_crop_locs(lblfile, view, height, width):
    """Compute a fixed-size crop window for ``view`` from the labeled neck point.

    Everything is in MATLAB indexing: returned coordinates are 1-based and
    inclusive.

    The label at point index ``5 + 10 * view`` in the first entry along axis
    0 of the loaded label array is mapped through a linear regression
    (``reg_view<N>_x`` / ``reg_view<N>_y`` from ``crop_reg_file``) to the
    crop's left/top edge. The window is then clamped to the image.
    Presumably axis 0 is frames and that point is the neck — TODO confirm
    against the label file layout.

    Args:
        lblfile: path to the label file read with ``apt.loadmat``.
        view: 0-based view index.
        height: image height in pixels, used to clamp the vertical extent.
        width: image width in pixels, used to clamp the horizontal extent.

    Returns:
        ``[x_left, x_right, y_top, y_bottom]``. The crop size comes from the
        module-level ``crop_size[view]`` as ``(width, height)``.

    Raises:
        ValueError: if the neck label's x or y coordinate is NaN (unlabeled).
    """
    curpts = _load_label_positions(lblfile)
    neck_locs = curpts[0, :, 5 + 10 * view]
    if np.isnan(neck_locs[:2]).any():
        raise ValueError(
            'Neck label for view {} is missing (NaN) in {}'.format(view, lblfile))

    reg_params = apt.loadmat(crop_reg_file)
    x_reg = reg_params['reg_view{}_x'.format(view + 1)]
    y_reg = reg_params['reg_view{}_y'.format(view + 1)]

    # Linear regression from neck position to the crop's left/top edge.
    x_left = int(round(x_reg[0] + x_reg[1] * neck_locs[0]))
    y_top = int(round(y_reg[0] + y_reg[1] * neck_locs[1]))

    x_left, x_right = _fit_crop_window(x_left, crop_size[view][0], width)
    y_top, y_bottom = _fit_crop_window(y_top, crop_size[view][1], height)
    return [x_left, x_right, y_top, y_bottom]
def create_tfrecords(args):
    """Write per-view, per-split train/val TFRecord files for the head-tracking data.

    For each of the 2 views and 3 cross-validation splits:
    - ``args['view']`` and ``args['split_num']`` are set, and ``get_conf(args)``
      supplies the output directory and file names.
    - Each image goes to the validation file when the split table holds 1
      for that image/split, and to the training file otherwise.
    - The ``[image index, frame number, 0]`` entries for each side are saved
      as ``[train, val]`` in ``splitdata.json`` in the same directory.

    Args:
        args: dict-like config. ``args['split_type'] == 'easy'`` selects the
            ``xvMain3Easy`` split table; anything else selects
            ``xvMain3Hard``. NOTE: ``args['view']`` and ``args['split_num']``
            are overwritten in place.
    """
    data_file = '/groups/branson/bransonlab/mayank/PoseTF/headTracking/trnDataSH_20180503_notable.mat'
    split_file = '/groups/branson/bransonlab/apt/experiments/data/trnSplits_20180509.mat'

    S = APT_interface.loadmat(split_file)
    if args['split_type'] == 'easy':
        split_arr = S['xvMain3Easy']
    else:
        split_arr = S['xvMain3Hard']

    # Context managers guarantee the HDF5 file and the writers are closed
    # even if something fails part-way through.
    with h5py.File(data_file, 'r') as D:
        images = D['IMain_crop2']
        nims = images.shape[1]
        # Stored indices are 1-based (MATLAB); convert to 0-based.
        movid = np.array(D['mov_id']).T - 1
        frame_num = np.array(D['frm']).T - 1

        for view in range(2):
            # Read only this view's labels, once per view, instead of loading
            # the whole label array for every split.
            all_locs = np.asarray(D['xyLblMain_crop2'][view])

            for split in range(3):
                args['view'] = view
                args['split_num'] = split
                conf = get_conf(args)
                outdir = conf.cachedir
                if not os.path.exists(outdir):
                    os.mkdir(outdir)
                train_filename = os.path.join(outdir, conf.trainfilename) + '.tfrecords'
                val_filename = os.path.join(outdir, conf.valfilename) + '.tfrecords'

                splits = [[], []]  # [train entries, val entries]
                with tf.python_io.TFRecordWriter(train_filename) as env, \
                        tf.python_io.TFRecordWriter(val_filename) as val_env:
                    for indx in range(nims):
                        # Images are stored as object references in the HDF5 file.
                        cur_im = np.array(D[images[view, indx]]).T
                        cur_locs = all_locs[..., indx].T
                        mov_num = movid[indx]
                        cur_frame_num = frame_num[indx]

                        is_val = split_arr[indx, split] == 1
                        cur_env = val_env if is_val else env
                        splits[1 if is_val else 0].append([indx, cur_frame_num[0], 0])

                        example = tf.train.Example(features=tf.train.Features(feature={
                            'height': int64_feature(cur_im.shape[0]),
                            'width': int64_feature(cur_im.shape[1]),
                            'depth': int64_feature(1),
                            'trx_ndx': int64_feature(0),
                            'locs': float_feature(cur_locs.flatten()),
                            'expndx': float_feature(mov_num),
                            'ts': float_feature(cur_frame_num[0]),
                            # tobytes() yields the same bytes as the deprecated tostring().
                            'image_raw': bytes_feature(cur_im.tobytes()),
                        }))
                        cur_env.write(example.SerializeToString())

                # NOTE(review): json.dump needs cur_frame_num[0] to be a float
                # (float64 subclasses float); an integer dtype here would fail
                # to serialize — confirm the 'frm' dataset dtype.
                with open(os.path.join(outdir, 'splitdata.json'), 'w') as f:
                    json.dump(splits, f)