def plot_nstep_errors(models, dtype, video_list, t_dim=50, gender=0, t0=21224, tsim=15,
                      visionF=1, labels=None, vlowmax=None, colors=['blue','red','green', 'magenta', 'purple', 'black']):
    """Load per-fold n-step error metric files for each model and test video.

    NOTE(review): this copy looks like an earlier / truncated debugging
    version. It stops on an interactive ``pdb.set_trace()`` followed by a
    debug ``print`` and never plots or returns anything (implicitly None).
    SOURCE also contains a later ``def plot_nstep_errors``; if both are in
    the same module, the later definition replaces this one at import time.

    Args:
        models: model tags ('const', 'copy', 'zero', or tags containing
            'rnn', 'skip', 'lr', ...) used to build the
            ``./metrics/<vpath>/<mtype>/...`` file names.
        dtype: dataset tag; when it contains 'pdb' or 'gmr91', file names of
            models whose tag contains 'lr' get an ``_200000epoch`` suffix.
        video_list: indexed as ``video_list[TEST]`` to get test video paths.
        t_dim, gender, vlowmax, colors: not used by the visible code.
        t0: start frame embedded in the metric file names.
        tsim: subtracted from the trajectory length to obtain ``t1`` for the
            metric file names.
        visionF: vision flag embedded in some metric file names.
        labels: defaults to ``models``; not used further in the visible code.
    """
    if labels is None: labels = models
    fname0 = 'eyrun_simulate_data.mat'
    basepath = '/groups/branson/home/bransonk/behavioranalysis/code/SSRNN/SSRNN/Data/bowl/'
    parentpath = '/groups/branson/home/imd/Documents/janelia/research/fly_behaviour_sim/71g01'  # unused below
    #ftag = ['velo', 'pos', 'bodyAng', 'wingang']
    ftag = ['velo', 'pos']
    for j in range(len(ftag)):
        pos_errs = []
        pos_stds = []
        #for testvideo_num, vpath in enumerate(video_list[TEST]):
        for testvideo_num in range(0, len(video_list[TEST])):
            vpath = video_list[TEST][testvideo_num]
            print('testvideo %d %s' % (testvideo_num, video_list[TEST][testvideo_num]))
            matfile = basepath + vpath + fname0
            (trx, motiondata, params, basesize) = load_eyrun_data(matfile)
            # The metric file names embed t1 = (number of frames) - tsim.
            t1 = trx['x'].shape[0] - tsim
            male_ind, female_ind = gender_classify(basesize['majax'])
            pos_err_models = []
            pos_std_models = []
            for mtype0 in models:
                if 'const' == mtype0 or 'copy' == mtype0 or 'zero' == mtype0:
                    # Baselines are stored as a single (non-fold) file.
                    fname = './metrics/' + vpath + '/' + mtype0 + '/' + mtype0 + '_visionF' + str(
                        visionF) + '_' + str(t0) + 't0_' + str(
                            t1) + 't1_30tsim.npy'
                    err_test = np.load(fname)
                else:
                    # Learned models: one file per cross-validation fold (10 folds).
                    err_tests = []
                    for kk in range(10):
                        if 'rnn' in mtype0:
                            fname = './metrics/' + vpath + '/' + mtype0 + '/' + mtype0 + '_' + str(
                                t0) + 't0_' + str(t1) + 't1_30tsim_150000epoch'
                        elif 'skip' in mtype0:
                            fname = './metrics/' + vpath + '/' + mtype0 + '/' + mtype0 + '_' + str(
                                t0) + 't0_' + str(t1) + 't1_30tsim_150000epoch'
                        else:
                            fname = './metrics/' + vpath + '/' + mtype0 + '/' + mtype0 + '_visionF' + str(
                                visionF) + '_' + str(t0) + 't0_' + str(
                                    t1) + 't1_30tsim'
                            if ('pdb' in dtype or 'gmr91' in dtype) and 'lr' in mtype0:
                                fname += '_200000epoch'
                        err_test0 = np.load(fname + '_%dfold.npy' % kk)
                        #if dtype == 'gmr91': err_test0 = err_test0[:,:,:,1:]
                        err_tests.append(err_test0)
                        # NOTE(review): leftover interactive breakpoint; this
                        # blocks any non-interactive run. Remove before use.
                        import pdb
                        pdb.set_trace()
                        print(fname, mtype0, err_test0.shape)
parser.add_argument('--nntype', type=str, default='conv4') parser.add_argument('--vtype', type=str, default='full') parser.add_argument('--dtype', type=str, default='pdb') parser.add_argument('--btype', type=str, default='perc') return check_args(parser.parse_args()) parentpath = '/groups/branson/home/imd/Documents/janelia/research/FlyTrajPred/pytorch/models/' trainF = 1 use_cuda = 1 if __name__ == '__main__': # parse arguments args = parse_args() args.y_dim = args.num_bin * args.m_dim matfile = '/groups/branson/home/bransonk/behavioranalysis/code/SSRNN/SSRNN/Code/eyrun_simulate_data.mat' (trx, motiondata, params, basesize) = load_eyrun_data(matfile) #bin_means = get_centre_bin(params['binedges'].T, tr_config['num_bin']) if trainF: model = trainIters(args, print_every=100) else: pass print(model) #print(tr_config['savepath']) #python -u main_conv.py --save_dir ./runs/conv4_cat50_relu/ --epoch 35000 --gender 0 --vtype full --visionOnly 0 --vision 1 --lr 0.01 --h_dim 128 --t_dim 50 --atype 'relu'
def simulate_flies(args, real_male_flies, real_female_flies,
                   simulated_male_flies, simulated_female_flies,
                   male_model=None, female_model=None, plottrxlen=1,
                   t0=0, t1=30320, vision_save=False, histoF=False, t_dim=50,
                   bookkeepingF=True, fname='small',
                   burning=100, DEBUG=1, testvideo_num=0,
                   sim_type='SMSF', batch_sz=32,
                   visionOnlyF=0, vpath='', ftag='',
                   fly_single_ind=0):
    """Roll out a mixed real/simulated fly population and save statistics.

    Frames ``t0+1 .. t1-1`` are stepped. For the first ``burning`` frames
    every fly replays the recorded trajectory (``trx``) while the model input
    window ``state`` is filled. After that, flies listed in
    ``simulated_*_flies`` are advanced by ``get_simulate_fly`` with the
    male/female model, and flies listed in ``real_*_flies`` by
    ``get_real_fly``.

    Side effects:
        * ``DEBUG >= 1`` creates a figure; ``DEBUG == 1`` also updates it and
          saves one PNG per simulated frame under ``./figs/``.
        * The simulated trajectory dict is written with ``sio.savemat`` to
          ``./trx/<mtype>_trx_<ftag>``.
        * ``bookkeepingF`` saves motion / position / body-pose / velocity /
          centre-distance arrays under ``./motion``, ``./velocity`` and
          ``./centredist``.
        * ``histoF`` computes histograms via ``histogram`` and saves them under
          ``./hist``.

    Args:
        args: parsed CLI namespace; reads ``mtype``, ``dtype``, ``visionF``,
            ``model_epoch``, ``h_dim`` and ``batch_sz``.
        real_male_flies, real_female_flies: fly indices replayed from data.
        simulated_male_flies, simulated_female_flies: fly indices driven by
            the male / female model.
        male_model, female_model: forwarded to ``model_selection``.
        plottrxlen: number of past frames of trail to draw.
        t0, t1: first frame and (exclusive) end frame of the rollout.
        histoF: also compute and save histograms.
        t_dim: length (in frames) of the model input window.
        bookkeepingF: save the per-frame statistics arrays.
        burning: number of frames replayed from data before simulating.
        DEBUG: plotting level (see above); ``DEBUG == 2`` also recomputes
            motion features into a local array that is not returned.
        testvideo_num: index into ``video16_path[args.dtype][TEST]`` used
            when ``vpath`` is empty; also embedded in output names.
        sim_type: 'Single', 'LOO', 'LONG', ... ; selects colours / trail
            drawing and is embedded in output names.
        visionOnlyF: model input consists of vision features only.
        vpath: video sub-path; looked up from ``testvideo_num`` if ''.
        ftag: prefix for the output file tag.
        fly_single_ind: fly whose full trail is drawn in 'Single' mode.
        vision_save, fname, batch_sz: accepted for backward compatibility;
            unused (the data file name is fixed, and figure names use
            ``args.batch_sz``).

    Returns:
        (male_pos, fale_pos): numpy arrays of the per-frame stacked x/y
        positions of the male and female flies.
    """
    from util_fly import compute_vision, motion2binidx, \
        binscores2motion, update_position
    from gen_dataset import video16_path
    from tqdm import tqdm

    # The original body used a bare `visionF` (figure names and ftag) that is
    # neither a parameter nor a local here -> NameError. The model calls
    # below already use args.visionF, so file naming follows the same flag.
    visionF = args.visionF

    if vpath == '':
        vpath = video16_path[args.dtype][TEST][testvideo_num]

    # Kept separate from the (unused) `fname` parameter on purpose.
    data_fname = 'eyrun_simulate_data.mat'
    basepath = '/groups/branson/home/bransonk/behavioranalysis/code/SSRNN/SSRNN/Data/bowl/'
    vision_matfile = basepath + vpath + 'movie-vision.mat'
    vc_data = load_vision(vision_matfile)[1:]
    matfile = basepath + vpath + data_fname
    trx, motiondata, params, basesize = load_eyrun_data(matfile)
    params['mtype'] = args.mtype
    n_flies = trx['x'].shape[1]
    T = trx['x'].shape[0]

    # Initial pose at frame t0.
    x = trx['x'][t0, :].copy()
    y = trx['y'][t0, :].copy()
    theta = trx['theta'][t0, :].copy()
    b = basesize['minax'].copy()
    a = trx['a'][t0, :].copy()
    l_wing_ang = trx['l_wing_ang'][t0, :].copy()
    r_wing_ang = trx['r_wing_ang'][t0, :].copy()
    l_wing_len = trx['l_wing_len'][t0, :].copy()
    r_wing_len = trx['r_wing_len'][t0, :].copy()
    xprev = x.copy()
    yprev = y.copy()
    thetaprev = theta.copy()

    # Model input window indexed (fly, channel, time). For conv models the
    # first 8 channels hold motion features and the rest vision features.
    if 'conv' in args.mtype:
        state = np.zeros((n_flies, NUM_VFEAT + NUM_MFEAT, t_dim))
        if t0 > t_dim:
            # Pre-fill with the t_dim frames preceding t0 (this used a
            # hard-coded 50, which only matched the default t_dim).
            state[:, :8, :] = motiondata[:, t0 - t_dim:t0, :].transpose(2, 0, 1)
            state[:, 8:, :] = vc_data[t0 - t_dim:t0, :, :].transpose(1, 2, 0)
    elif visionOnlyF:
        state = np.zeros((n_flies, NUM_VFEAT, t_dim))
    else:
        state = np.zeros((n_flies, NUM_MFEAT, t_dim))
    feat_motion = motiondata[:, t0, :].copy()

    male_ind, female_ind = gender_classify(basesize['majax'])
    mymotiondata = np.zeros(motiondata.shape)  # only filled when DEBUG == 2

    traj_keys = ('x', 'y', 'theta', 'a', 'b',
                 'l_wing_ang', 'r_wing_ang', 'l_wing_len', 'r_wing_len')
    tsim = T if t1 is None else t1
    # Rows up to t0 are copied from data; later rows are written per step.
    simtrx = {key: np.concatenate((trx[key][:t0 + 1, :],
                                   np.zeros((tsim, n_flies))))
              for key in traj_keys}

    # Fly visualisation initialisation.
    if DEBUG >= 1:
        plt.figure(figsize=(15, 15))
        ax = plt.axes([0, 0, 1, 1])
        if sim_type == 'Single':
            fly_j = simulated_male_flies[0]
            colors = get_default_fly_colors_single(fly_j, n_flies)
            print(colors)
        elif sim_type == 'LOO':
            colors = get_default_fly_colors_black(n_flies)
        else:
            colors = get_default_fly_colors_rb(n_flies + 1)
        plt.imshow(params['bg'], cmap=cm.gray, vmin=0., vmax=1.)
        htrx = []
        for fly in range(n_flies):
            htrxcurr, = ax.plot(x[fly], y[fly], '-',
                                color=np.append(colors[fly, :-1], .5),
                                linewidth=3)
            htrx.append(htrxcurr)
        hbodies, hflies, htexts = draw_flies(
            x, y, a, b, theta, l_wing_ang, r_wing_ang, l_wing_len, r_wing_len,
            ax=ax, colors=colors, textOff=True)
        counter_plt = plt.annotate(
            '{:.2f}sec'.format(0. / default_params['FPS']),
            xy=[1024 - 55, params['bg'].shape[0] - 45],
            xytext=(0, 3),  # 3 points vertical offset
            textcoords="offset points",
            ha='center', va='bottom', size=18, color='black')
        plt.axis('off')

    load_start = time.time()
    male_model, female_model = model_selection(
        args, male_model, female_model, params,
        model_epoch=args.model_epoch, visionOnlyF=visionOnlyF)
    load_end = time.time()
    print('Model Loading time %f' % ((load_end - load_start) / 60.0))

    # Per-frame statistics ("fale" = female throughout).
    bins = 100
    male_bucket = np.zeros([bins, bins])  # NOTE(review): never filled
    fale_bucket = np.zeros([bins, bins])
    male_dist_centre, fale_dist_centre = [], []
    male_velocity, fale_velocity = [], []
    # NOTE(review): the first entries come from frame 0, not t0 - confirm.
    male_pos = [np.hstack([trx['x'][0, male_ind], trx['y'][0, male_ind]])]
    fale_pos = [np.hstack([trx['x'][0, female_ind], trx['y'][0, female_ind]])]
    male_body_pos = [[trx['theta'][0, male_ind],
                      trx['l_wing_ang'][0, male_ind],
                      trx['r_wing_ang'][0, male_ind],
                      trx['l_wing_len'][0, male_ind],
                      trx['r_wing_len'][0, male_ind]]]
    fale_body_pos = [[trx['theta'][0, female_ind],
                      trx['l_wing_ang'][0, female_ind],
                      trx['r_wing_ang'][0, female_ind],
                      trx['l_wing_len'][0, female_ind],
                      trx['r_wing_len'][0, female_ind]]]
    male_motion, fale_motion = [], []

    scale_bar_drawn = False
    print('Simulation Start...\n')
    for counter, t in tqdm(enumerate(range(t0 + 1, t1))):
        xprev[:] = x
        yprev[:] = y
        thetaprev[:] = theta
        male_dist_centre.append([x[male_ind] - default_params['arena_center_x'],
                                 y[male_ind] - default_params['arena_center_y']])
        fale_dist_centre.append([x[female_ind] - default_params['arena_center_x'],
                                 y[female_ind] - default_params['arena_center_y']])

        if (t - t0 - 1) > burning:
            # Simulated males.
            (x, y, theta, a, l_wing_ang, r_wing_ang, l_wing_len, r_wing_len,
             _, male_vision_chamber, feat_motion, state) = get_simulate_fly(
                 male_model, state, t, trx, simulated_male_flies, feat_motion,
                 x, y, theta, a, b, l_wing_ang, r_wing_ang,
                 l_wing_len, r_wing_len, xprev, yprev, thetaprev,
                 basesize, params, args.mtype,
                 visionF=args.visionF, visionOnly=visionOnlyF, t_dim=t_dim)
            # Simulated females.
            (x, y, theta, a, l_wing_ang, r_wing_ang, l_wing_len, r_wing_len,
             _, female_vision_chamber, feat_motion, state) = get_simulate_fly(
                 female_model, state, t, trx, simulated_female_flies, feat_motion,
                 x, y, theta, a, b, l_wing_ang, r_wing_ang,
                 l_wing_len, r_wing_len, xprev, yprev, thetaprev,
                 basesize, params, args.mtype,
                 visionF=args.visionF, visionOnly=visionOnlyF, t_dim=t_dim)
            # Real (replayed) males.
            (x, y, theta, a, l_wing_ang, r_wing_ang, l_wing_len, r_wing_len,
             male_vision_chamber, feat_motion) = get_real_fly(
                 real_male_flies, motiondata, feat_motion, t, trx,
                 x, y, theta, l_wing_ang, r_wing_ang, l_wing_len, r_wing_len,
                 a, b, params)
            # Real (replayed) females.
            (x, y, theta, a, l_wing_ang, r_wing_ang, l_wing_len, r_wing_len,
             female_vision_chamber, feat_motion) = get_real_fly(
                 real_female_flies, motiondata, feat_motion, t, trx,
                 x, y, theta, l_wing_ang, r_wing_ang, l_wing_len, r_wing_len,
                 a, b, params)
        else:
            # Burn-in: replay recorded poses and slide the input window.
            for fly in range(n_flies):
                x[fly] = trx['x'][t, fly]
                y[fly] = trx['y'][t, fly]
                theta[fly] = trx['theta'][t, fly]
                a[fly] = trx['a'][t, fly]
                l_wing_ang[fly] = trx['l_wing_ang'][t, fly]
                r_wing_ang[fly] = trx['r_wing_ang'][t, fly]
                l_wing_len[fly] = trx['l_wing_len'][t, fly]
                r_wing_len[fly] = trx['r_wing_len'][t, fly]
                if 'conv' in args.mtype:
                    state[fly, :, :(t_dim - 1)] = state[fly, :, 1:]
                    state[fly, :8, -1] = feat_motion[:, fly]
                    state[fly, 8:, -1] = vc_data[t, fly, :]
                elif visionOnlyF:
                    # This was a separate `if`, so conv models also fell into
                    # the motion-only branch below and shifted the window
                    # twice per frame. `elif` mirrors the initialisation.
                    state[fly, :, :(t_dim - 1)] = state[fly, :, 1:]
                    state[fly, :, -1] = vc_data[t, fly, :]
                else:
                    state[fly, :, :(t_dim - 1)] = state[fly, :, 1:]
                    state[fly, :8, -1] = feat_motion[:, fly]
                feat_motion[:, fly] = motiondata[:, t, fly]
                if DEBUG == 2:
                    mymotiondata[:, t, fly] = compute_motion(
                        xprev[fly], yprev[fly], thetaprev[fly], x[fly], y[fly],
                        theta[fly], a[fly], l_wing_ang[fly], r_wing_ang[fly],
                        l_wing_len[fly], r_wing_len[fly], basesize, t, fly,
                        params)

        male_velocity.append([x[male_ind] - xprev[male_ind],
                              y[male_ind] - yprev[male_ind]])
        fale_velocity.append([x[female_ind] - xprev[female_ind],
                              y[female_ind] - yprev[female_ind]])
        male_pos.append(np.hstack([x[male_ind].copy(), y[male_ind].copy()]))
        male_motion.append(feat_motion[:, male_ind].copy())
        male_body_pos.append(np.asarray([theta[male_ind],
                                         l_wing_ang[male_ind], r_wing_ang[male_ind],
                                         l_wing_len[male_ind], r_wing_len[male_ind]]))
        fale_pos.append(np.hstack([x[female_ind].copy(), y[female_ind].copy()]))
        fale_motion.append(feat_motion[:, female_ind].copy())
        fale_body_pos.append(np.asarray([theta[female_ind],
                                         l_wing_ang[female_ind], r_wing_ang[female_ind],
                                         l_wing_len[female_ind], r_wing_len[female_ind]]))

        current = (x, y, theta, a, b, l_wing_ang, r_wing_ang, l_wing_len, r_wing_len)
        for key, value in zip(traj_keys, current):
            simtrx[key][t, :] = value

        if DEBUG == 1 and (t - t0 - 1) > burning:
            tprev0 = np.maximum(t0 + 1, t - plottrxlen)
            tprev1 = t0 + plottrxlen
            tprev2 = np.maximum(t0 + burning, t - plottrxlen)
            for fly in range(n_flies):
                if 'Single' in sim_type and fly_single_ind == fly:
                    start = tprev1
                elif 'LONG' in sim_type:
                    start = tprev2
                else:
                    start = tprev0
                htrx[fly].set_data(simtrx['x'][start:t + 1, fly],
                                   simtrx['y'][start:t + 1, fly])
            if 'Single' in sim_type:
                ax.plot(simtrx['x'][tprev0:tprev1, fly_single_ind],
                        simtrx['y'][tprev0:tprev1, fly_single_ind],
                        '-', color='thistle', linewidth=3)
            update_flies(hbodies, hflies, htexts, x, y, a, b, theta,
                         l_wing_ang, r_wing_ang, l_wing_len, r_wing_len)
            plt.pause(.001)
            counter_plt.set_text('{:.2f}sec'.format(counter / default_params['FPS']))

            if not scale_bar_drawn:
                # Static scale bar: draw it once. It used to be re-added on
                # every frame, piling up artists and slowing each savefig.
                plt.annotate(
                    '{:.2f}ppm'.format(default_params['PPM'] * 10),
                    xy=[55, params['bg'].shape[0] - 45],
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom', size=14, color='black')
                plt.plot([20, 20 + default_params['PPM'] * 10],
                         [params['bg'].shape[0] - 40, params['bg'].shape[0] - 40],
                         '-', color='black', linewidth=2.)
                scale_bar_drawn = True

            # Every plotted frame is saved (the former guard
            # `t % 1 == 0 and t < t1` is always true inside this loop).
            os.makedirs('./figs/sim/%s/' % vpath, exist_ok=True)
            if 'nn' in params['mtype'] or 'conv' in params['mtype']:
                plt.savefig('./figs/sim/%s/%s_%s_' % (vpath, params['mtype'], sim_type)
                            + str(args.h_dim) + 'hid_tau%d_' % (50)
                            + 'vision%d_visionOnly%d' % (visionF, visionOnlyF)
                            + '_epoch' + str(args.model_epoch)
                            + '_%dbs_%s_%05d.png' % (args.batch_sz, args.dtype, t),
                            format='png')
            elif 'lr' in params['mtype']:
                plt.savefig('./figs/sim/%s/lr_%s_' % (vpath, sim_type) + str(t_dim)
                            + 'tau_vision%d_visionOnly%d_%05d.png' % (visionF, visionOnlyF, t),
                            format='png')
            else:
                os.makedirs('./figs/all/', exist_ok=True)
                plt.savefig('./figs/all/data_1000frames_%5d.png' % t,
                            format='png', bbox_inches='tight')

    mtype = args.mtype
    run_tag = (str(t0) + 't0_' + str(t1) + 't1_' + sim_type
               + '_testvideo%d_%s' % (testvideo_num, args.dtype))
    if visionOnlyF:
        ftag += 'visionOnly_' + run_tag
    elif visionF:
        ftag += run_tag
    else:
        ftag += 'visionF0_' + run_tag
    print('ftag %s' % ftag)

    def _save_all(subdir, named_arrays):
        """np.save each (name, array) as ./<subdir>/<vpath>/<mtype><name><ftag>."""
        os.makedirs('./' + subdir + '/' + vpath, exist_ok=True)
        prefix = './' + subdir + '/' + vpath + '/' + mtype
        for name, arr in named_arrays:
            np.save(prefix + name + ftag, arr)

    arena_radius = default_params['arena_radius']
    male_motion = np.asarray(male_motion)
    fale_motion = np.asarray(fale_motion)
    male_body_pos = np.asarray(male_body_pos)
    fale_body_pos = np.asarray(fale_body_pos)
    os.makedirs('./trx', exist_ok=True)
    print('./trx/' + mtype + '_trx_' + ftag)
    sio.savemat('./trx/' + mtype + '_trx_' + ftag, simtrx)

    male_pos = np.asarray(male_pos)
    fale_pos = np.asarray(fale_pos)
    if bookkeepingF:
        _save_all('motion', [('_motion_male_', male_motion),
                             ('_motion_fale_', fale_motion),
                             ('_position_male_', male_pos),
                             ('_position_fale_', fale_pos),
                             ('_body_position_male_', male_body_pos),
                             ('_body_position_fale_', fale_body_pos)])

    # Per-frame speed of every male/female fly, flattened over frames.
    male_velocity = np.sqrt(np.sum(np.asarray(male_velocity)**2, axis=1)).flatten()
    fale_velocity = np.sqrt(np.sum(np.asarray(fale_velocity)**2, axis=1)).flatten()
    moving_male_ind = (male_velocity > 1.0)
    moving_fale_ind = (fale_velocity > 1.0)
    male_velocity_ns = male_velocity[moving_male_ind]
    fale_velocity_ns = fale_velocity[moving_fale_ind]
    if bookkeepingF:
        _save_all('velocity', [('_velocity_male_', male_velocity),
                               ('_velocity_fale_', fale_velocity),
                               ('_velocity_woStationary_male_ind_', moving_male_ind),
                               ('_velocity_woStationary_fale_ind_', moving_fale_ind),
                               ('_velocity_woStationary_male_', male_velocity_ns),
                               ('_velocity_woStationary_fale_', fale_velocity_ns)])
    if histoF:
        male_histo = histogram(male_velocity / 105,
                               fname=mtype + '_velocity_male_histo_' + ftag,
                               title='Velocity (Male)')
        fale_histo = histogram(fale_velocity / 105,
                               fname=mtype + '_velocity_fale_histo_' + ftag,
                               title='Velocity (Female)')
        male_histo_ns = histogram(male_velocity_ns / 105,
                                  fname=mtype + '_velocity_woStationary_male_histo_' + ftag,
                                  title='Velocity (Male)')
        fale_histo_ns = histogram(fale_velocity_ns / 105,
                                  fname=mtype + '_velocity_woStationary_fale_histo_' + ftag,
                                  title='Velocity (Female)')
        _save_all('hist', [('_velocity_male_histo_', male_histo),
                           ('_velocity_fale_histo_', fale_histo),
                           ('_velocity_woStationary_male_histo_', male_histo_ns),
                           ('_velocity_woStationary_fale_histo_', fale_histo_ns)])

    # Distance of each fly to the arena centre, per frame.
    male_dist_centre = np.sqrt(np.sum(np.asarray(male_dist_centre)**2, axis=1)).flatten()
    fale_dist_centre = np.sqrt(np.sum(np.asarray(fale_dist_centre)**2, axis=1)).flatten()
    male_dist_centre_ = male_dist_centre / arena_radius
    fale_dist_centre_ = fale_dist_centre / arena_radius
    male_dist_centre_ns = male_dist_centre[moving_male_ind]
    fale_dist_centre_ns = fale_dist_centre[moving_fale_ind]
    male_dist_centre_ns_ = male_dist_centre_ns / arena_radius
    fale_dist_centre_ns_ = fale_dist_centre_ns / arena_radius
    if bookkeepingF:
        _save_all('centredist', [('_centredist_male_', male_dist_centre),
                                 ('_centredist_fale_', fale_dist_centre),
                                 ('_centredist_woStationary_male_', male_dist_centre_ns),
                                 ('_centredist_woStationary_fale_', fale_dist_centre_ns)])
    if histoF:
        male_dist_histo = histogram(male_dist_centre_,
                                    fname=mtype + '_dist2centre_male_histo_' + ftag,
                                    title='Distance to Centre (Male)')
        fale_dist_histo = histogram(fale_dist_centre_,
                                    fname=mtype + '_dist2centre_fale_histo_' + ftag,
                                    title='Distance to Centre (Female)')
        male_dist_histo_ns = histogram(
            male_dist_centre_ns_,
            fname=mtype + '_dist2centre_woStationary_male_histo_' + ftag,
            title='Distance to Centre (Male) excluding stationary flies')
        fale_dist_histo_ns = histogram(
            fale_dist_centre_ns_,
            fname=mtype + '_dist2centre_woStationary_fale_histo_' + ftag,
            title='Distance to Centre (Female) excluding stationary flies')
        _save_all('hist', [('_dist_male_histo_', male_dist_histo),
                           ('_dist_fale_histo_', fale_dist_histo),
                           ('_dist_woStationary_male_histo_', male_dist_histo_ns),
                           ('_dist_woStationary_fale_histo_', fale_dist_histo_ns)])

    # NOTE(review): the buckets are never filled, so this always prints 0.0.
    male_bucket = male_bucket / (t1 - t0)
    fale_bucket = fale_bucket / (t1 - t0)
    print(max(male_bucket.max(), fale_bucket.max()))
    return male_pos, fale_pos
def data4Kristin(mtype0, vpath, t_dim=50, gender='male', t0=0, tsim=30,
                 visionF=1, dtype='pdb'):
    """Load the saved n-step error arrays of one model on one video.

    File names follow the ``./metrics/<vpath>/<mtype0>/<mtype0>...``
    convention used when the metrics were written.

    Args:
        mtype0: model tag. 'const', 'copy' and 'zero' are single-file
            baselines; any other tag is loaded as 10 per-fold files.
        vpath: video sub-path under the bowl data directory.
        t_dim, gender: accepted for backward compatibility; unused.
        t0: start frame embedded in the metric file names.
        tsim: subtracted from the video length to obtain ``t1`` for the file
            names.
        visionF: vision flag embedded in some metric file names.
        dtype: dataset tag. This was previously read from an undefined name
            (NameError). When it contains 'pdb' or 'gmr91', file names of
            models whose tag contains 'lr' get an ``_200000epoch`` suffix.
            Defaults to 'pdb', the ``--dtype`` CLI default.

    Returns:
        List of arrays as returned by ``np.load``: one entry for the baselines,
        ten (one per fold) otherwise. For tags containing 'lr', each fold's
        array is perturbed in place with zero-mean Gaussian noise whose std,
        per (axis-0 index, axis-2 index), is the std over axis 1 of the fold
        averaged over its last axis. This assumes 4-D arrays (TODO confirm
        axis meaning).
    """
    fname0 = 'eyrun_simulate_data.mat'
    basepath = '/groups/branson/home/bransonk/behavioranalysis/code/SSRNN/SSRNN/Data/bowl/'
    print('%s' % (vpath))
    matfile = basepath + vpath + fname0
    trx, motiondata, params, basesize = load_eyrun_data(matfile)
    # The metric file names embed t1 = (number of frames) - tsim.
    t1 = trx['x'].shape[0] - tsim
    prefix = './metrics/' + vpath + '/' + mtype0 + '/' + mtype0

    err_tests = []
    if mtype0 in ('const', 'copy', 'zero'):
        fname = (prefix + '_visionF' + str(visionF) + '_' + str(t0) + 't0_'
                 + str(t1) + 't1_30tsim.npy')
        err_tests.append(np.load(fname))
    else:
        if 'rnn' in mtype0 or 'skip' in mtype0:
            fname = prefix + '_' + str(t0) + 't0_' + str(t1) + 't1_30tsim_150000epoch'
        else:
            fname = (prefix + '_visionF' + str(visionF) + '_' + str(t0) + 't0_'
                     + str(t1) + 't1_30tsim')
        if ('pdb' in dtype or 'gmr91' in dtype) and 'lr' in mtype0:
            fname += '_200000epoch'
        for kk in range(10):
            err_tests.append(np.load(fname + '_%dfold.npy' % kk))

    # `mtype` was an undefined name here (NameError); the model tag is mtype0.
    if 'lr' in mtype0:
        for fold_k in range(10):
            fold = err_tests[fold_k]
            err_stds = fold.mean(axis=-1).std(axis=1)
            for feat in range(err_stds.shape[0]):
                for step in range(err_stds.shape[1]):
                    eps = np.random.normal(0, err_stds[feat, step],
                                           size=(fold.shape[1], fold.shape[-1]))
                    fold[feat, :, step, :] += eps
    return err_tests
def _default_fly_split(testvideo_num):
    """Return (simulated_male_flies, simulated_female_flies) for a test video.

    Hard-coded for the current test set: videos 1 and 2 use males 0..8 and
    females 9..18; the others use males 0..9 and females 10..19.
    TODO: derive this from the data so other video sets work.
    """
    if testvideo_num in (1, 2):
        return np.arange(0, 9, 1), np.arange(9, 19, 1)
    return np.arange(0, 10, 1), np.arange(10, 20, 1)


def main(argv=None):
    """Command-line entry point.

    With ``--plotF`` it plots n-step errors for a fixed set of models.
    Otherwise it evaluates ``args.mtype`` on every test video of
    ``video16_path[args.dtype]``. Learned models run 10 folds through
    ``real_flies_simulatePlan_RNNs``; the 'zero', 'const' and 'copy'
    baselines call their prediction functions. Unknown model types do nothing.

    Args:
        argv: argument list; defaults to ``sys.argv[1:]``.
    """
    if argv is None:
        argv = sys.argv[1:]
    fname = 'eyrun_simulate_data.mat'
    args = parse_args(argv)
    args.y_dim = args.num_bin * args.num_mfeat
    video_list = video16_path[args.dtype]
    test_videos = video_list[TEST]

    if args.plotF:
        colors = ['black', 'silver', 'red', 'green', 'deepskyblue', 'mediumpurple']
        models = ['const', 'zero', 'lr50', 'conv4_cat50', 'rnn50', 'skip50']
        labels = ['CONST', 'HALT', 'LINEAR', 'CNN', 'RNN', 'HRNN']
        vlowmax = [[2, 4.75], [5, 75], [0.1, 0.9], [0.04, 0.08], [0.04, 0.08]]
        plot_nstep_errors(models, args.dtype, video_list, t_dim=args.t_dim,
                          gender=args.gender, t0=0, tsim=args.tsim,
                          visionF=args.visionF, labels=labels, colors=colors,
                          vlowmax=vlowmax)
        return

    if args.mtype == 'lr50':
        # Linear-regression weights, one file per gender.
        weights = []
        for gender_id in (0, 1):
            save_path = ('./runs/linear_reg_' + str(args.t_dim)
                         + 'tau/%s/model/weight_gender%d' % (args.dtype, gender_id))
            if not args.visionF:
                save_path = save_path + '_visionF0'
            weights.append(np.load(save_path + '.npy'))
        male_model, female_model = weights

        for testvideo_num, vpath in enumerate(test_videos):
            simulated_male_flies, simulated_female_flies = _default_fly_split(testvideo_num)
            print('testvideo %d %s' % (testvideo_num, vpath))
            for ifold in range(10):
                real_flies_simulatePlan_RNNs(vpath, male_model, female_model,
                                             simulated_male_flies, simulated_female_flies,
                                             monlyF=abs(1 - args.visionF), ifold=ifold,
                                             tsim=args.tsim, mtype=args.mtype,
                                             t_dim=args.t_dim, num_bin=args.num_bin)

    elif args.mtype in ('nn4_cat50', 'conv4_cat50'):
        args.visionF = 1
        vtype = 'full'
        model_epoch = 25000  # same for both model types
        batch_sz = 100 if args.mtype == 'nn4_cat50' else 32
        from simulate_autoreg import model_selection
        for testvideo_num, vpath in enumerate(test_videos):
            print('testvideo %d %s' % (testvideo_num, vpath))
            matfile = args.datapath + vpath + fname
            _, _, params, basesize = load_eyrun_data(matfile)
            simulated_male_flies, simulated_female_flies = _default_fly_split(testvideo_num)
            params['mtype'] = args.mtype
            male_model, female_model = model_selection(
                None, None, params, visionF=args.visionF,
                model_epoch=model_epoch, t_dim=args.t_dim, vtype=vtype,
                batch_sz=batch_sz, mtype=args.mtype, dtype=args.dtype)
            for ifold in range(10):
                # Keyword was `num_bins` here but `num_bin` in the other two
                # call sites; use the consistent spelling.
                real_flies_simulatePlan_RNNs(vpath, male_model, female_model,
                                             simulated_male_flies, simulated_female_flies,
                                             monlyF=abs(1 - args.visionF), ifold=ifold,
                                             model_epoch=model_epoch,
                                             tsim=args.tsim, mtype=args.mtype,
                                             t_dim=args.t_dim, num_bin=args.num_bin)

    elif 'rnn' in args.mtype or 'skip' in args.mtype:
        model_epoch = 200000  # hardcoded number of epochs of training
        from simulate_rnn import model_selection
        for testvideo_num, vpath in enumerate(test_videos):
            print('testvideo %d/%d %s' % (testvideo_num, len(test_videos), vpath))
            simulated_male_flies, simulated_female_flies = _default_fly_split(testvideo_num)
            male_model, female_model, male_hiddens, female_hiddens = model_selection(
                args, None, None, args.videotype, args.mtype, model_epoch,
                args.h_dim, simulated_male_flies, simulated_female_flies,
                dtype=args.dtype, btype=args.btype)
            for ifold in range(10):
                real_flies_simulatePlan_RNNs(vpath, male_model, female_model,
                                             simulated_male_flies, simulated_female_flies,
                                             male_hiddens, female_hiddens,
                                             model_epoch=model_epoch,
                                             monlyF=abs(1 - args.visionF), ifold=ifold,
                                             tsim=args.tsim, mtype=args.mtype,
                                             t_dim=args.t_dim, btype=args.btype,
                                             num_bin=args.num_bin, gender=args.gender)

    else:
        if args.mtype == 'zero':
            baseline = baseline0_zero_nstep_prediction
        elif args.mtype == 'const':
            baseline = baseline0_constant_nstep_prediction
        elif args.mtype == 'copy':
            baseline = baseline1_constVel_nstep_prediction
        else:
            return
        for testvideo_num, vpath in enumerate(test_videos):
            print('testvideo %d %s' % (testvideo_num, vpath))
            # The 'copy' path used `agrs.tsim` (NameError); fixed here.
            baseline(vpath, tsim=args.tsim)
def plot_nstep_errors(models, dtype, video_list, t_dim=50, gender=0, t0=21224,
                      tsim=15, visionF=1, labels=None, vlowmax=None,
                      colors=('blue', 'red', 'green', 'magenta', 'purple', 'black'),
                      metrics_root='.'):
    """Plot mean n-step prediction error against horizon for several models.

    For each test video and model, the saved error array is loaded once from
    ``<metrics_root>/metrics/<vpath>/<model>/...``. Fold-based models are
    reduced with an element-wise minimum over their 10 folds. For every error
    feature in ``ftag``, the per-step ``nanmean`` over the chosen gender's
    flies is averaged over videos and saved as
    ``./figs/nstep/<dtype>/eval_<tsim>steps_<feature>_gender<g>_<dtype>.pdf``.

    Args:
        models: model tags used to build the metric file names.
        dtype: dataset tag. Used in output paths, and when it contains 'pdb'
            or 'gmr91', it adds an ``_200000epoch`` suffix for 'lr' models.
        video_list: indexed as ``video_list[TEST]`` to get test video paths.
        t_dim, vlowmax: accepted for backward compatibility; unused.
        gender: 0 selects male flies, anything else female flies.
        t0: start frame embedded in the metric file names.
        tsim: number of steps; steps 1..tsim-2 are plotted at x = 2..tsim-1.
        visionF: vision flag embedded in some metric file names.
        labels: legend labels (default: ``models``).
        colors: one matplotlib colour per model.
        metrics_root: directory containing ``metrics/``. It replaces a
            reference to a non-existent ``args.basepath`` (``args`` is not in
            scope here). Default '.' matches where the metrics are written.
    """
    if labels is None:
        labels = models
    fname0 = 'eyrun_simulate_data.mat'
    basepath = '/groups/branson/home/bransonk/behavioranalysis/code/SSRNN/SSRNN/Data/bowl/'
    ftag = ['velo', 'pos', 'bodyAng', 'wingang']

    # feature_curves[j][video][model] -> per-step mean error for feature j.
    # Each file is loaded once here instead of once per feature.
    feature_curves = [[] for _ in ftag]
    for testvideo_num, vpath in enumerate(video_list[TEST]):
        print('testvideo %d %s' % (testvideo_num, vpath))
        matfile = basepath + vpath + fname0
        trx, motiondata, params, basesize = load_eyrun_data(matfile)
        # The metric file names embed t1 = (number of frames) - tsim.
        t1 = trx['x'].shape[0] - tsim
        male_ind, female_ind = gender_classify(basesize['majax'])
        fly_ind = male_ind if gender == 0 else female_ind

        video_curves = [[] for _ in ftag]
        for mtype0 in models:
            prefix = metrics_root + '/metrics/' + vpath + '/' + mtype0 + '/' + mtype0
            if mtype0 in ('const', 'copy', 'zero'):
                fname = (prefix + '_visionF' + str(visionF) + '_' + str(t0)
                         + 't0_' + str(t1) + 't1_30tsim.npy')
                err_test = np.load(fname)
            else:
                if 'rnn' in mtype0:
                    fname = prefix + '_' + str(t0) + 't0_' + str(t1) + 't1_30tsim_150000epoch'
                elif 'skip' in mtype0:
                    fname = prefix + '_' + str(t0) + 't0_' + str(t1) + 't1_30tsim_100000epoch'
                else:
                    fname = (prefix + '_visionF' + str(visionF) + '_' + str(t0)
                             + 't0_' + str(t1) + 't1_30tsim')
                    if ('pdb' in dtype or 'gmr91' in dtype) and 'lr' in mtype0:
                        fname += '_200000epoch'
                err_tests = [np.load(fname + '_%dfold.npy' % kk) for kk in range(10)]
                # NOTE(review): element-wise minimum over folds (best fold per
                # entry) - confirm this is the intended summary, not a mean.
                err_test = np.min(err_tests, axis=0)
            print(fname)
            for j in range(len(ftag)):
                video_curves[j].append([np.nanmean(err_test[j, :, i, fly_ind])
                                        for i in range(1, tsim - 1)])
        for j, curves in enumerate(video_curves):
            feature_curves[j].append(curves)

    xx = np.arange(2, tsim)
    for j, feature in enumerate(ftag):
        # Average over videos -> one curve per model.
        pos_err_models = np.nanmean(feature_curves[j], axis=0)
        plt.figure()
        ax = plt.axes([0, 0, 1, 1])
        for i, pos_err_test in enumerate(pos_err_models):
            plt.errorbar(xx, pos_err_test, ls='-', color=colors[i],
                         label=labels[i], lw=3, alpha=0.8)
        if 'body' in feature:
            plt.ylim([0, pos_err_models[1:].max()])
        plt.xlabel('N-steps')
        plt.ylabel('Error rate')
        SMALL_SIZE = 22
        matplotlib.rc('font', size=SMALL_SIZE)
        matplotlib.rc('axes', titlesize=SMALL_SIZE)
        ax.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5, 1.25), ncol=3)
        os.makedirs('./figs/nstep/%s/' % dtype, exist_ok=True)
        plt.savefig('./figs/nstep/%s/eval_%dsteps_%s_gender%d_%s.pdf'
                    % (dtype, tsim, feature, gender, dtype),
                    format='pdf', bbox_inches='tight')
        # Release the figure; the original leaked one per feature.
        plt.close()
    return
def _const_velocity_extrapolate(series, t, tsim):
    """Extrapolate ``series`` forward from frame ``t`` at constant velocity.

    ``series`` is indexed ``series[frame, fly]``. Row ``k`` (0 <= k < tsim) of
    the returned ``(tsim, n_flies)`` array is
    ``series[t] + k * (series[t] - series[t - 1])``, so row 0 is frame ``t``.
    """
    current = series[t, :]
    velocity = current - series[t - 1, :]
    return (velocity[:, None] * np.arange(tsim)[None, :]).T + current[None, :]


def baseline1_constVel_nstep_prediction(vpath, gender=1, t0=0, t1=None, tsim=15,
                                        mtype='const'):
    """Score a constant-velocity baseline for ``tsim``-step prediction on one video.

    Every ``tsim`` frames, each trajectory field is extrapolated from its last
    observed per-frame change (see ``_const_velocity_extrapolate``). The result
    is scored with ``get_error`` for horizons 1..tsim-1. The stacked
    errors ``[vel, pos, theta, wing_ang, wing_len]`` are saved to
    ``./metrics/<vpath>/<mtype>/<mtype>_visionF1_<t0>t0_<t1>t1_<tsim>tsim.npy``.

    Args:
        vpath: video sub-path under the bowl data directory.
        gender: unused; kept for interface compatibility.
        t0: first frame; windows start at ``t0 + t_dim``.
        t1: end frame (exclusive); defaults to ``n_frames - tsim``.
        tsim: prediction horizon, which is also the stride between windows.
        mtype: metrics sub-directory / filename prefix. It was previously an
            undefined name here; 'const' matches what ``plot_nstep_errors``
            reads for this baseline.

    NOTE(review): ``t_dim`` is not a parameter; it is read from module scope.
    NOTE(review): theta is extrapolated without angle wrapping. Confirm that
    ``get_error`` wraps angular differences.
    """
    fname = 'eyrun_simulate_data.mat'
    basepath = '/groups/branson/home/bransonk/behavioranalysis/code/SSRNN/SSRNN/Data/bowl/'
    matfile = basepath + vpath + fname
    (trx, motiondata, params, basesize) = load_eyrun_data(matfile)

    print("TSIM: %d" % tsim)
    if t1 is None:
        t1 = trx['x'].shape[0] - tsim

    fields = ('x', 'y', 'theta', 'l_wing_ang', 'r_wing_ang', 'l_wing_len', 'r_wing_len')
    vel_errors, pos_errors, theta_errors, wing_ang_errors, wing_len_errors = \
        [], [], [], [], []

    print('Simulation Start...\n')
    progress = tqdm(enumerate(range(t0 + t_dim, t1, tsim)))
    for ii, t in progress:
        # The first 50 windows are skipped, as in the original implementation.
        if ii < 50:
            continue
        simtrx = {key: _const_velocity_extrapolate(trx[key], t, tsim) for key in fields}

        # get_error result indices 2..6 hold the vel/pos/theta/wing errors.
        step_results = [get_error(simtrx, trx, t, tt) for tt in range(1, tsim)]
        vel_error = np.asarray([r[2] for r in step_results])
        pos_error = np.asarray([r[3] for r in step_results])
        theta_error = np.asarray([r[4] for r in step_results])
        wing_ang_error = np.asarray([r[5] for r in step_results])
        wing_len_error = np.asarray([r[6] for r in step_results])

        vel_errors.append(vel_error)
        pos_errors.append(pos_error)
        theta_errors.append(theta_error)
        wing_ang_errors.append(wing_ang_error)
        wing_len_errors.append(wing_len_error)

        # With tsim == 1 there are no horizons to report (indexing [0] would fail).
        if step_results:
            progress.set_description(
                ('VEL MSE: %f POSITION MSE : %f THETA MSE %f '
                 + 'WING ANG MSE %f WING LEN MSE %f')
                % (np.mean(vel_error[0]), np.mean(pos_error[0]), np.mean(theta_error[0]),
                   np.nanmean(wing_ang_error[0]), np.nanmean(wing_len_error[0])))

    results = np.stack([vel_errors, pos_errors, theta_errors,
                        wing_ang_errors, wing_len_errors])
    # exist_ok + makedirs also creates the intermediate ./metrics/<vpath>/ dir.
    os.makedirs('./metrics/%s/%s' % (vpath, mtype), exist_ok=True)
    fname = './metrics/' + vpath + '/' + mtype + '/' + mtype + '_visionF1_' + str(
        t0) + 't0_' + str(t1) + 't1_%dtsim' % tsim
    print(fname)
    np.save(fname, np.asarray(results))

    print('Final Velocity Error %f' % (np.nanmean(vel_errors)))
    print('Final Position Error %f' % (np.nanmean(pos_errors)))
    print('Final Theta Error %f' % (np.nanmean(theta_errors)))
    print('Final Wing Ang Error %f' % (np.nanmean(wing_ang_errors)))
    print('Final Wing Len Error %f' % (np.nanmean(wing_len_errors)))
def real_flies_simulatePlan_RNNs(vpath, male_model, female_model,
                                 simulated_male_flies, simulated_female_flies,
                                 hiddens_male=None, hiddens_female=None, mtype='rnn',
                                 monlyF=0, plottrxlen=100, tsim=1, t0=0, t1=None,
                                 t_dim=7, genDataset=False, ifold=0, binwidth=2.0,
                                 num_hid=100, model_epoch=200000, btype='linear',
                                 num_bin=51, gender=0):
    """Evaluate ``tsim``-step predictions of the male/female models on one video.

    Windows start every ``tsim`` frames from ``t0 + t_dim``. In each window:

    * all flies are reset to their tracked pose and motion at frame ``t``;
    * the state is passed to ``get_nstep_comparison_rnn`` (when 'rnn' or 'skip'
      is in ``mtype``) or ``get_nstep_comparison`` (otherwise);
    * the returned trajectory is scored with ``get_error`` for horizons
      1..tsim-1.

    For recurrent models, the hidden states are first advanced with
    ``get_simulate_fly`` over the preceding ``tsim`` tracked frames.

    Side effects (file names unchanged from the original):
      * simulated trajectories go to ``./simtrx/<vpath>/<mtype>/...`` for
        recurrent and 'lr' models;
      * if ``genDataset``, the vision+trajectory dataset and the frame indices
        go to ``./fakedata/...``;
      * the stacked errors ``[vel, pos, theta, wing_ang, wing_len]`` go to
        ``<args.outdir>/metrics/<vpath>/<mtype>/...`` (``args`` is a
        module-level global).

    ``simulated_male_flies``, ``simulated_female_flies`` (recomputed from
    ``gender_classify``), ``plottrxlen`` and ``binwidth`` are ignored. For
    recurrent models, ``hiddens_male``/``hiddens_female`` are re-initialised.

    Returns:
        The trajectory from the last window, or None if no window was run.

    Raises:
        ValueError: if ``mtype`` names no supported model family
            ('rnn', 'skip', 'conv', 'lr', 'nn').
    """
    print(mtype, monlyF, tsim)
    fname = 'eyrun_simulate_data.mat'
    basepath = '/groups/branson/home/bransonk/behavioranalysis/code/SSRNN/SSRNN/Data/bowl/'
    matfile = basepath + vpath + fname
    (trx, motiondata, params, basesize) = load_eyrun_data(matfile)
    vision_matfile = basepath + vpath + 'movie-vision.mat'
    vc_data = load_vision(vision_matfile)[1:]

    if 'perc' in btype:
        params['binedges'] = np.load('./bins/percentile_%dbins.npy' % num_bin)

    male_ind, female_ind = gender_classify(basesize['majax'])
    params['mtype'] = mtype
    is_recurrent = 'rnn' in mtype or 'skip' in mtype

    # Initial pose, taken at the first window start.
    print("TSIM: %d" % tsim)
    if t1 is None:
        t1 = trx['x'].shape[0] - tsim
    t_init = t0 + t_dim
    x = trx['x'][t_init, :].copy()
    y = trx['y'][t_init, :].copy()
    theta = trx['theta'][t_init, :].copy()
    n_flies = len(x)
    # Every fly is reset to its tracked state at the start of each window.
    real_flies = np.arange(0, n_flies, 1)
    b = basesize['minax'].copy()
    a = trx['a'][t_init, :].copy()
    l_wing_ang = trx['l_wing_ang'][t_init, :].copy()
    r_wing_ang = trx['r_wing_ang'][t_init, :].copy()
    l_wing_len = trx['l_wing_len'][t_init, :].copy()
    r_wing_len = trx['r_wing_len'][t_init, :].copy()
    xprev = x.copy()
    yprev = y.copy()
    thetaprev = theta.copy()

    # Males occupy indices [0, n_male), females follow.
    simulated_male_flies = np.arange(len(male_ind))
    simulated_female_flies = np.arange(len(male_ind), len(male_ind) + len(female_ind))
    if is_recurrent:
        hiddens_male = [male_model.initHidden(1, use_cuda=0)
                        for _ in range(len(simulated_male_flies))]
        hiddens_female = [female_model.initHidden(1, use_cuda=0)
                          for _ in range(len(simulated_female_flies))]
    simulated_flies = np.hstack([simulated_male_flies, simulated_female_flies])
    NUM_FLY = len(simulated_male_flies) + len(simulated_female_flies)
    print('Number of flies : %d' % NUM_FLY)

    if is_recurrent:
        male_state = [None] * (len(male_ind))
        female_state = [None] * (len(female_ind))
    elif 'conv' in mtype:
        state = np.zeros((NUM_FLY, NUM_VFEAT + NUM_MFEAT, t_dim))
    elif 'lr' in mtype or 'nn' in mtype:
        state = np.zeros((NUM_FLY, NUM_MFEAT, t_dim))
    else:
        # Previously this surfaced later as a NameError on ``state``.
        raise ValueError('unsupported mtype %r' % (mtype,))

    feat_motion = motiondata[:, t_init, :].copy()
    vel_errors, pos_errors, theta_errors, wing_ang_errors, wing_len_errors = \
        [], [], [], [], []
    simtrx_numpys, dataset, dataset_frames = [], [], []
    simtrx_curr = None

    print('Simulation Start...\n')
    progress = tqdm(enumerate(range(t_init, t1, tsim)))
    for ii, t in progress:
        print(ii, t)
        if is_recurrent:
            # Burn in the hidden states on the preceding tracked frames. The
            # start is clamped at 0 so early windows do not wrap around to
            # the end of the trajectory through negative indexing.
            for t_j in range(max(t - tsim, 0), t):
                x[:] = trx['x'][t_j, :]
                y[:] = trx['y'][t_j, :]
                theta[:] = trx['theta'][t_j, :]
                a[:] = trx['a'][t_j, :]
                l_wing_ang[:] = trx['l_wing_ang'][t_j, :]
                r_wing_ang[:] = trx['r_wing_ang'][t_j, :]
                l_wing_len[:] = trx['l_wing_len'][t_j, :]
                r_wing_len[:] = trx['r_wing_len'][t_j, :]
                feat_motion[:, :] = motiondata[:, t_j, :]
                xprev[:] = x
                yprev[:] = y
                thetaprev[:] = theta

                # Simulate male model
                x, y, theta, a, l_wing_ang, r_wing_ang, \
                    l_wing_len, r_wing_len, \
                    hiddens_male, _, feat_motion, _ = get_simulate_fly(
                        male_model, male_state, t_j, trx,
                        simulated_male_flies, feat_motion,
                        x, y, theta, a, b,
                        l_wing_ang, r_wing_ang,
                        l_wing_len, r_wing_len,
                        xprev, yprev, thetaprev, basesize, params, mtype,
                        num_bin=num_bin, hiddens=hiddens_male)

                # Simulate female model
                x, y, theta, a, l_wing_ang, r_wing_ang, \
                    l_wing_len, r_wing_len, \
                    hiddens_female, _, feat_motion, _ = get_simulate_fly(
                        female_model, female_state, t_j, trx,
                        simulated_female_flies, feat_motion,
                        x, y, theta, a, b,
                        l_wing_ang, r_wing_ang,
                        l_wing_len, r_wing_len,
                        xprev, yprev, thetaprev,
                        basesize, params, mtype,
                        num_bin=num_bin, hiddens=hiddens_female)

        # Reset every fly to its tracked state at frame t.
        for fly in real_flies:
            x[fly] = trx['x'][t, fly]
            y[fly] = trx['y'][t, fly]
            theta[fly] = trx['theta'][t, fly]
            a[fly] = trx['a'][t, fly]
            l_wing_ang[fly] = trx['l_wing_ang'][t, fly]
            r_wing_ang[fly] = trx['r_wing_ang'][t, fly]
            l_wing_len[fly] = trx['l_wing_len'][t, fly]
            r_wing_len[fly] = trx['r_wing_len'][t, fly]
            # motiondata[:, t, fly] corresponds to movement from t-1 to t
            feat_motion[:, fly] = motiondata[:, t, fly]

        if 'conv' in mtype:
            # History window of length t_dim (was hard-coded to 50, which
            # mismatched the t_dim-wide state). t >= t0 + t_dim, so the
            # slice never goes negative.
            state = np.zeros((NUM_FLY, NUM_VFEAT + NUM_MFEAT, t_dim))
            state[:, :NUM_MFEAT, :] = motiondata[:, t - t_dim:t, :].transpose(2, 0, 1)
            state[:, NUM_MFEAT:, :] = vc_data[t - t_dim:t, :, :].transpose(1, 2, 0)

        # Start doing nstep (tsim) predictions at time step t
        if is_recurrent:
            simtrx_curr, feat_motion, _, hiddens_male, hiddens_female, flyvisions = \
                get_nstep_comparison_rnn(
                    x, y, theta, a, b,
                    l_wing_ang, r_wing_ang,
                    l_wing_len, r_wing_len,
                    NUM_FLY, trx,
                    male_model, female_model,
                    male_state, female_state,
                    feat_motion,
                    params, basesize, motiondata,
                    tsim, t, simulated_flies,
                    monlyF=monlyF,
                    mtype=mtype,
                    male_hiddens=hiddens_male,
                    female_hiddens=hiddens_female,
                    male_ind=male_ind,
                    female_ind=female_ind, num_bin=num_bin)
        else:
            simtrx_curr, feat_motion, _, flyvisions = get_nstep_comparison(
                x, y, theta, a, b,
                l_wing_ang, r_wing_ang,
                l_wing_len, r_wing_len,
                NUM_FLY, trx,
                male_model, female_model,
                male_ind, female_ind,
                state, feat_motion,
                params, basesize, motiondata,
                tsim, t, simulated_flies,
                monlyF=monlyF, mtype=mtype,
                t_dim=t_dim, num_bin=num_bin)

        if genDataset:
            flyvisions = np.asarray(flyvisions)
            data = combine_vision_data(simtrx_curr, flyvisions, num_fly=NUM_FLY, num_burn=2)
            dataset.append(data)
            dataset_frames.append(t)

        simtrx_numpys.append(simtrx2numpy(simtrx_curr))

        # get_error result indices 2..6 hold the vel/pos/theta/wing errors.
        step_results = [get_error(simtrx_curr, trx, t, tt) for tt in range(1, tsim)]
        vel_error = np.asarray([r[2] for r in step_results])
        pos_error = np.asarray([r[3] for r in step_results])
        theta_error = np.asarray([r[4] for r in step_results])
        wing_ang_error = np.asarray([r[5] for r in step_results])
        wing_len_error = np.asarray([r[6] for r in step_results])

        vel_errors.append(vel_error)
        pos_errors.append(pos_error)
        theta_errors.append(theta_error)
        wing_ang_errors.append(wing_ang_error)
        wing_len_errors.append(wing_len_error)

        # With the default tsim == 1 there are no horizons; indexing [-1]
        # would raise IndexError.
        if step_results:
            progress.set_description(
                ('%d VEL MSE: %f POSITION MSE : %f THETA MSE %f '
                 + 'WING ANG MSE %f WING LEN MSE %f')
                % (t, np.nanmean(vel_error[-1]), np.nanmean(pos_error[-1]),
                   np.nanmean(theta_error[-1]),
                   np.nanmean(wing_ang_error[-1]),
                   np.nanmean(wing_len_error[-1])))

    # makedirs with exist_ok also creates the intermediate directories.
    if is_recurrent:
        os.makedirs('./simtrx/%s/%s' % (vpath, mtype), exist_ok=True)
        np.save('./simtrx/' + vpath + '/' + mtype + '/' + mtype + '_gender'
                + str(gender) + '_' + str(num_hid) + 'hid_' + str(t0) + 't0_'
                + str(t1) + 't1_%dtsim_%s_%depoch' % (tsim, btype, model_epoch) + str(ifold),
                np.asarray(simtrx_numpys))
    elif 'lr' in mtype:
        os.makedirs('./simtrx/%s/%s' % (vpath, mtype), exist_ok=True)
        np.save('./simtrx/' + vpath + '/' + mtype + '/' + mtype + '_gender'
                + str(gender) + '_' + str(t0) + 't0_' + str(t1)
                + 't1_%dtsim' % tsim + str(ifold),
                np.asarray(simtrx_numpys))

    if genDataset:
        os.makedirs('./fakedata/%s/%s' % (vpath, mtype), exist_ok=True)
        if 'lr' in mtype:
            ffname = './fakedata/' + vpath + '/' + mtype + '/' + mtype + '_gender' \
                + str(gender) + '_' + str(t0) \
                + 't0_' + str(t1) + 't1_%dtsim' % tsim
            np.save(ffname, np.asarray(dataset))
            print('Data Generated Path: %s' % ffname)
        else:
            np.save('./fakedata/' + vpath + '/' + mtype + '/' + mtype + '_gender'
                    + str(gender) + '_' + str(num_hid) + 'hid_' + str(t0)
                    + 't0_' + str(t1) + 't1_%dtsim_%s_%depoch' % (tsim, btype, model_epoch),
                    np.asarray(dataset))
        np.save('./fakedata/' + vpath + '/frame_index_'
                + str(t0) + 't0_' + str(t1) + 't1_%dtsim' % (tsim),
                np.asarray(dataset_frames))

    visionF = 1 - int(monlyF)
    results = np.stack([vel_errors, pos_errors, theta_errors,
                        wing_ang_errors, wing_len_errors])
    os.makedirs('%s/metrics/%s/%s' % (args.outdir, vpath, mtype), exist_ok=True)
    if is_recurrent:
        fname = args.outdir + '/metrics/' + vpath + '/' + mtype + '/' + mtype + '_' + str(
            t0) + 't0_' + str(t1) + 't1_%dtsim_%s_%depoch_%dfold' % (
            tsim, btype, model_epoch, ifold)
    else:
        fname = args.outdir + '/metrics/' + vpath + '/' + mtype + '/' + mtype + '_visionF' + str(
            visionF) + '_' + str(t0) + 't0_' + str(
            t1) + 't1_%dtsim_%depoch_%dfold' % (tsim, model_epoch, ifold)
    print(fname)
    np.save(fname, np.asarray(results))

    print('Final Velocity Error %f' % (np.nanmean(vel_errors)))
    print('Final Position Error %f' % (np.nanmean(pos_errors)))
    print('Final Theta Error %f' % (np.nanmean(theta_errors)))
    print('Final Wing Ang Error %f' % (np.nanmean(wing_ang_errors)))
    print('Final Wing Len Error %f' % (np.nanmean(wing_len_errors)))
    return simtrx_curr