def test_model():
    """Smoke-test ``GPObservation`` caching and an HMM fit on a 5-row toy series.

    Checks that
      * ``obs.kernel_distsq_gg`` matches a hand-computed squared-distance table,
      * ``kernel_distsq_doubled`` on the first animal matches a hand-computed
        table,
      * ``obs.log_prob`` is identical with no cache, with the precomputed
        ``kernel_distsq_xg_*`` tensors, and with the ``Sigma_*`` / ``A_*``
        caches from ``get_gp_cache``,
    then fits an ``HMM`` with Adam, decodes the most likely states, runs
    0- and 2-step prediction and draws a short sample.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    x_grids = np.array([0.0, 10.0])
    y_grids = np.array([0.0, 8.0])
    bounds = np.array([[0.0, 10.0], [0.0, 8.0], [0.0, 10.0], [0.0, 8.0]])

    # Columns 0:2 and 2:4 are used below as two separate 2-D point sets.
    data = np.array([[1.0, 1.0, 1.0, 6.0],
                     [3.0, 6.0, 8.0, 6.0],
                     [4.0, 7.0, 8.0, 5.0],
                     [6.0, 7.0, 5.0, 6.0],
                     [8.0, 2.0, 6.0, 1.0]])
    data = torch.tensor(data, dtype=torch.float64)

    K = 3
    D = 4
    M = 0

    obs = GPObservation(K=K, D=D, x_grids=x_grids, y_grids=y_grids,
                        bounds=bounds, train_rs=True)

    # Squared distances between the four grid corners (0,0), (0,8), (10,0),
    # (10,8); each row/column of the expected table appears twice.
    correct_kerneldist_gg = torch.tensor(
        [[0., 0., 64., 64., 100., 100., 164., 164.],
         [0., 0., 64., 64., 100., 100., 164., 164.],
         [64., 64., 0., 0., 164., 164., 100., 100.],
         [64., 64., 0., 0., 164., 164., 100., 100.],
         [100., 100., 164., 164., 0., 0., 64., 64.],
         [100., 100., 164., 164., 0., 0., 64., 64.],
         [164., 164., 100., 100., 64., 64., 0., 0.],
         [164., 164., 100., 100., 64., 64., 0., 0.]],
        dtype=torch.float64)
    assert torch.all(torch.eq(correct_kerneldist_gg, obs.kernel_distsq_gg)), \
        obs.kernel_distsq_gg

    log_prob_nocache = obs.log_prob(data)
    print("log_prob_nocache = {}".format(log_prob_nocache))

    # Distances from every point except the last one to the inducing points.
    kernel_distsq_xg_a = kernel_distsq_doubled(data[:-1, 0:2], obs.inducing_points)
    kernel_distsq_xg_b = kernel_distsq_doubled(data[:-1, 2:4], obs.inducing_points)

    # Rows: the 4 points of data[:-1, 0:2], each listed twice;
    # columns: the 4 grid corners, each listed twice.
    correct_kernel_distsq_xg_a = torch.tensor(
        [[2., 2., 50., 50., 82., 82., 130., 130.],
         [2., 2., 50., 50., 82., 82., 130., 130.],
         [45., 45., 13., 13., 85., 85., 53., 53.],
         [45., 45., 13., 13., 85., 85., 53., 53.],
         [65., 65., 17., 17., 85., 85., 37., 37.],
         [65., 65., 17., 17., 85., 85., 37., 37.],
         [85., 85., 37., 37., 65., 65., 17., 17.],
         [85., 85., 37., 37., 65., 65., 17., 17.]],
        dtype=torch.float64)
    assert torch.all(torch.eq(correct_kernel_distsq_xg_a, kernel_distsq_xg_a)), \
        kernel_distsq_xg_a

    # First-level cache: distances only. Must reproduce the uncached value.
    memory_kwargs = dict(kernel_distsq_xg_a=kernel_distsq_xg_a,
                         kernel_distsq_xg_b=kernel_distsq_xg_b)
    log_prob = obs.log_prob(data, **memory_kwargs)
    print("log_prob = {}".format(log_prob))
    assert torch.all(torch.eq(log_prob_nocache, log_prob))

    # Second-level cache: Sigma / A per animal (index 0 for a, 1 for b).
    Sigma_a, A_a = obs.get_gp_cache(data[:-1, 0:2], 0, **memory_kwargs)
    Sigma_b, A_b = obs.get_gp_cache(data[:-1, 2:4], 1, **memory_kwargs)
    memory_kwargs_2 = dict(Sigma_a=Sigma_a, A_a=A_a, Sigma_b=Sigma_b, A_b=A_b)
    print("calculating log prob 2...")
    log_prob2 = obs.log_prob(data, **memory_kwargs_2)
    assert torch.all(torch.eq(log_prob, log_prob2))

    model = HMM(K=K, D=D, M=M, transition="stationary", observation=obs)
    model.observation.mus_init = data[0] * torch.ones(K, D, dtype=torch.float64)

    # fit
    losses, opt = model.fit(data, optimizer=None, method='adam', num_iters=100,
                            lr=0.01, pbar_update_interval=10, **memory_kwargs)
    plt.figure()
    plt.plot(losses)
    plt.show()

    # most-likely-z
    print("Most likely z...")
    z = model.most_likely_states(data, **memory_kwargs)

    # prediction
    print("0 step prediction")
    # At most the last 1000 steps; slicing returns the whole series when it is
    # shorter, so no explicit length branch is needed.
    data_to_predict = data[-1000:]
    x_predict = k_step_prediction_for_gpmodel(model, z, data_to_predict, **memory_kwargs)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()), axis=0)
    print("0 step prediction error = {}".format(x_predict_err))

    print("2 step prediction")
    x_predict_2 = k_step_prediction(model, z, data_to_predict, k=2)
    # A k-step prediction lines up with the data shifted by k.
    x_predict_2_err = np.mean(np.abs(x_predict_2 - data_to_predict[2:].numpy()), axis=0)
    print("2 step prediction error = {}".format(x_predict_2_err))

    # samples
    sample_T = 5
    sample_z, sample_x = model.sample(sample_T)
def test_model():
    """Smoke-test an HMM with a ``LinearGridTransformation`` AR observation.

    Builds a 3-state model on a 5-row toy series, precomputes per-animal grid
    point indices, grid points and corner feature vectors, fits with Adam,
    decodes the most likely states, runs 0- and 2-step prediction and draws
    a short sample.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    x_grids = np.array([0.0, 5.0, 10.0])
    y_grids = np.array([0.0, 4.0, 8.0])

    # Columns 0:2 and 2:4 are used below as two separate 2-D point sets.
    data = np.array([[1.0, 1.0, 1.0, 6.0],
                     [3.0, 6.0, 8.0, 6.0],
                     [4.0, 7.0, 8.0, 5.0],
                     [6.0, 7.0, 5.0, 6.0],
                     [8.0, 2.0, 6.0, 1.0]])
    data = torch.tensor(data, dtype=torch.float64)

    def toy_feature_vec_func(s):
        """Direction features from points ``s`` towards the four arena corners.

        :param s: self, (T, 2)
        :return: features, (T, Df, 2)
        """
        # The corners of the [0, 10] x [0, 8] box given by `bounds` below.
        corners = torch.tensor([[0, 0], [0, 8], [10, 0], [10, 8]], dtype=torch.float64)
        return feature_direction_vec(s, corners)

    K = 3
    D = 4
    M = 0
    Df = 4  # one feature per corner

    bounds = np.array([[0.0, 10.0], [0.0, 8.0], [0.0, 10.0], [0.0, 8.0]])

    tran = LinearGridTransformation(K=K, D=D, x_grids=x_grids, y_grids=y_grids,
                                    Df=Df, feature_vec_func=toy_feature_vec_func)
    obs = ARTruncatedNormalObservation(K=K, D=D, M=0, lags=1, bounds=bounds,
                                       transformation=tran)

    model = HMM(K=K, D=D, M=M, transition="stationary", observation=obs)
    model.observation.mus_init = data[0] * torch.ones(K, D, dtype=torch.float64)

    # calculate memory (every step except the last, once per animal)
    gridpoints_idx_a = tran.get_gridpoints_idx_for_batch(data[:-1, 0:2])
    gridpoints_idx_b = tran.get_gridpoints_idx_for_batch(data[:-1, 2:4])
    gridpoints_a = tran.get_gridpoints_for_batch(gridpoints_idx_a)
    gridpoints_b = tran.get_gridpoints_for_batch(gridpoints_idx_b)
    feature_vecs_a = toy_feature_vec_func(data[:-1, 0:2])
    feature_vecs_b = toy_feature_vec_func(data[:-1, 2:4])

    gridpoints_idx = (gridpoints_idx_a, gridpoints_idx_b)
    gridpoints = (gridpoints_a, gridpoints_b)
    feature_vecs = (feature_vecs_a, feature_vecs_b)

    # fit
    losses, opt = model.fit(data, optimizer=None, method='adam', num_iters=100,
                            lr=0.01, pbar_update_interval=10,
                            gridpoints=gridpoints, gridpoints_idx=gridpoints_idx,
                            feature_vecs=feature_vecs)
    plt.figure()
    plt.plot(losses)
    plt.show()

    # most-likely-z
    print("Most likely z...")
    z = model.most_likely_states(data, gridpoints_idx=gridpoints_idx,
                                 feature_vecs=feature_vecs)

    # prediction
    print("0 step prediction")
    # At most the last 1000 steps; slicing returns the whole series when shorter.
    data_to_predict = data[-1000:]
    x_predict = k_step_prediction_for_lineargrid_model(
        model, z, data_to_predict, gridpoints_idx=gridpoints_idx,
        feature_vecs=feature_vecs)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()), axis=0)
    print("0 step prediction error = {}".format(x_predict_err))

    print("2 step prediction")
    x_predict_2 = k_step_prediction(model, z, data_to_predict, k=2)
    x_predict_2_err = np.mean(np.abs(x_predict_2 - data_to_predict[2:].numpy()), axis=0)
    print("2 step prediction error = {}".format(x_predict_2_err))

    # samples
    sample_T = 5
    sample_z, sample_x = model.sample(sample_T)
def main(job_name, cuda_num, data_type, downsample_n, filter_traj, load_model,
         load_model_dir, load_opt_dir, animal, transition, sticky_alpha,
         sticky_kappa, k, x_grids, y_grids, n_x, n_y, rs, train_rs, train_vs,
         train_model, pbar_update_interval, video_clips, prop_start_end,
         held_out_proportion, torch_seed, np_seed, list_of_num_iters,
         ckpts_not_to_save, list_of_lr, list_of_k_steps, sample_t, quiver_scale):
    """Fit (or load) a GP-observation HMM on trajectory data and save results.

    Steps:
      1. Parse the comma-separated CLI strings.
      2. Load and downsample the trajectory (optionally speed-filtered).
      3. Split it chronologically into training and validation parts.
      4. Build a new model or load a saved one.
      5. Write ``exp_params.json``.
      6. Precompute transition/observation memory tensors.
      7. Either train over the ``(num_iters, lr)`` schedule, writing one
         ``checkpoint_<i>`` directory per stage, or only save results.

    Args:
        job_name: name of the result directory; required.
        cuda_num: CUDA device index (anything ``int()`` accepts); the CPU is
            used when CUDA is unavailable.
        data_type: ``'full'`` or ``'selected_010_virgin'``.
        animal: ``'both'``, ``'virgin'`` (columns 0:2) or ``'mother'``
            (columns 2:4).
        video_clips: ``"start,end"`` used to slice the ``'full'`` data.
        prop_start_end: ``"start,end"`` proportions used to slice the
            ``'selected_010_virgin'`` data.
        list_of_num_iters, list_of_lr: comma-separated training schedule of
            equal length; every learning rate must be <= 1.
        list_of_k_steps, ckpts_not_to_save: optional comma-separated ints.
        held_out_proportion: trailing fraction kept for validation, in
            ``[0, 0.4]``.
        load_opt_dir: optimizer file to resume from; empty or None means a
            fresh optimizer.
        sample_t: sample length; the training length is used when None.

    Raises:
        ValueError: missing ``job_name``, a learning rate > 1, or an
            unsupported ``data_type``.
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")

    cuda_num = int(cuda_num)
    device = torch.device(
        "cuda:{}".format(cuda_num) if torch.cuda.is_available() else "cpu")
    print("Using device {} \n\n".format(device))

    assert animal in ['both', 'virgin', 'mother'], animal

    K = k
    sample_T = sample_t
    rs = float(rs) if rs else None

    # ---- parse comma-separated CLI arguments ----
    video_clip_start, video_clip_end = [
        float(x) for x in video_clips.split(",")
    ]
    start, end = [float(x) for x in prop_start_end.split(",")]
    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    # TODO: fix for no k_steps, k > 0
    list_of_k_steps = [int(x) for x in list_of_k_steps.split(",")
                       ] if list_of_k_steps else []

    assert len(list_of_num_iters) == len(
        list_of_lr
    ), "Length of list_of_num_iters must match length of list_of_lr."
    for lr in list_of_lr:
        if lr > 1:
            raise ValueError("Learning rate should not be larger than 1!")

    ckpts_not_to_save = [int(x) for x in ckpts_not_to_save.split(',')
                         ] if ckpts_not_to_save else []

    # Data and result paths are resolved relative to the git repository root.
    repo = git.Repo('.', search_parent_directories=True)
    repo_dir = repo.working_tree_dir  # SocialBehavior

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    ########################## data ########################
    if data_type == 'full':
        data_dir = repo_dir + '/SocialBehaviorptc/data/trajs_all'
        traj = joblib.load(data_dir)
        # NOTE(review): 36000 looks like the number of frames per video clip — confirm.
        traj = traj[int(36000 * video_clip_start):int(36000 * video_clip_end)]
    elif data_type == 'selected_010_virgin':
        assert animal == 'virgin', "animal must be 'virgin', but got {}.".format(
            animal)
        data_dir = repo_dir + '/SocialBehaviorptc/data/traj_010_virgin_selected'
        traj = joblib.load(data_dir)
        T = len(traj)
        traj = traj[int(T * start):int(T * end)]
    else:
        raise ValueError("unsupported data type: {}".format(data_type))

    if animal == 'virgin':
        traj = traj[:, 0:2]
    elif animal == 'mother':
        traj = traj[:, 2:4]

    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)

    data = torch.tensor(traj, dtype=torch.float64, device=device)
    assert 0 <= held_out_proportion <= 0.4, \
        "held_out-portion should be between 0 and 0.4 (inclusive), but is {}".format(held_out_proportion)
    T = data.shape[0]
    # Chronological split; named `split_idx` so the builtin `breakpoint` is
    # not shadowed.
    split_idx = int(T * (1 - held_out_proportion))
    training_data = data[:split_idx]
    valid_data = data[split_idx:]

    if sample_T is None:
        sample_T = training_data.shape[0]

    ######################### model ####################
    D = data.shape[1]
    assert D == 4 or D == 2, D
    M = 0

    if load_model:
        print("Loading the model from ", load_model_dir)
        model = joblib.load(load_model_dir)
        obs = model.observation
        K = model.K

        assert D == model.D, "D = {}, model.D = {}".format(D, model.D)

        n_x = len(obs.x_grids) - 1
        n_y = len(obs.y_grids) - 1
        bounds = obs.bounds
    else:
        print("Creating the model...")
        if D == 4:
            bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX],
                               [ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])
        else:
            bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])

        # grids: evenly spaced edges unless explicit comma-separated edges are given
        if x_grids is None:
            x_grid_gap = (ARENA_XMAX - ARENA_XMIN) / n_x
            x_grids = np.array(
                [ARENA_XMIN + i * x_grid_gap for i in range(n_x + 1)])
        else:
            x_grids = np.array([float(x) for x in x_grids.split(",")])
            n_x = len(x_grids) - 1

        if y_grids is None:
            y_grid_gap = (ARENA_YMAX - ARENA_YMIN) / n_y
            y_grids = np.array(
                [ARENA_YMIN + i * y_grid_gap for i in range(n_y + 1)])
        else:
            y_grids = np.array([float(x) for x in y_grids.split(",")])
            n_y = len(y_grids) - 1

        # Every state starts at the first training observation.
        mus_init = training_data[0] * torch.ones(
            K, D, dtype=torch.float64, device=device)
        if animal == 'both':
            obs = GPObservation(K=K, D=D, mus_init=mus_init, x_grids=x_grids,
                                y_grids=y_grids, bounds=bounds, rs=rs,
                                train_rs=train_rs, train_vs=train_vs,
                                device=device)
        else:
            obs = GPObservationSingle(K=K, D=D, mus_init=mus_init,
                                      x_grids=x_grids, y_grids=y_grids,
                                      bounds=bounds, rs=rs, train_rs=train_rs,
                                      device=device)

        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        elif transition == 'grid':
            transition_kwargs = dict(x_grids=x_grids, y_grids=y_grids)
        else:
            transition_kwargs = None
        model = HMM(K=K, D=D, M=M, transition=transition, observation=obs,
                    transition_kwargs=transition_kwargs, device=device)

    # save experiment params
    exp_params = {
        "job_name": job_name,
        "data_type": data_type,
        'downsample_n': downsample_n,
        "filter_traj": filter_traj,
        "load_model": load_model,
        "load_model_dir": load_model_dir,
        "load_opt_dir": load_opt_dir,
        "animal": animal,
        "transition": transition,
        "sticky_alpha": sticky_alpha,
        "sticky_kappa": sticky_kappa,
        "K": K,
        "x_grids": x_grids,
        "y_grids": y_grids,
        "n_x": n_x,
        "n_y": n_y,
        "rs": rs,
        "train_rs": train_rs,
        "train_vs": train_vs,
        "train_model": train_model,
        "pbar_update_interval": pbar_update_interval,
        "video_clip_start": video_clip_start,
        "video_clip_end": video_clip_end,
        "start_percentage": start,
        "end_percentage": end,
        "held_out_proportion": held_out_proportion,
        "torch_seed": torch_seed,
        "np_seed": np_seed,
        "list_of_num_iters": list_of_num_iters,
        "list_of_lr": list_of_lr,
        "list_of_k_steps": list_of_k_steps,
        "ckpts_not_to_save": ckpts_not_to_save,
        "sample_T": sample_T,
        "quiver_scale": quiver_scale
    }

    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/gp/{}/{}".format(animal, job_name))
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving exp_params to rlst_dir: ", rslt_dir)
    with open(rslt_dir + "/exp_params.json", "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    # compute memory
    if transition == "grid":
        print("Computing transition memory...")
        joint_grid_idx = model.transition.get_grid_idx(training_data[:-1])
        transition_memory_kwargs = dict(joint_grid_idx=joint_grid_idx)
        valid_joint_grid_idx = model.transition.get_grid_idx(valid_data[:-1])
        valid_data_transition_memory_kwargs = dict(
            joint_grid_idx=valid_joint_grid_idx)
    else:
        transition_memory_kwargs = None
        valid_data_transition_memory_kwargs = None

    print("Computing observation memory...")

    def get_memory_kwargs(data, train_rs):
        """Precompute observation memory for `data` (empty dict for empty data).

        Squared distances to the inducing points are always computed; when
        `train_rs` is False the GP caches derived from them are returned instead.
        """
        if data is None or data.shape[0] == 0:
            return {}

        if animal == 'both':
            kernel_distsq_xg_a = kernel_distsq(data[:-1, 0:2], obs.inducing_points)
            kernel_distsq_xg_b = kernel_distsq(data[:-1, 2:4], obs.inducing_points)
            kernel_distsq_dict = dict(kernel_distsq_xg_a=kernel_distsq_xg_a,
                                      kernel_distsq_xg_b=kernel_distsq_xg_b)
        else:
            kernel_distsq_xg = kernel_distsq(data[:-1], obs.inducing_points)
            kernel_distsq_dict = dict(kernel_distsq_xg=kernel_distsq_xg)

        if train_rs:
            return kernel_distsq_dict

        if animal == 'both':
            Sigma_a, A_a = obs.get_gp_cache(data[:-1, 0:2], 0, **kernel_distsq_dict)
            Sigma_b, A_b = obs.get_gp_cache(data[:-1, 2:4], 1, **kernel_distsq_dict)
            return dict(Sigma_a=Sigma_a, A_a=A_a, Sigma_b=Sigma_b, A_b=A_b)
        Sigma, A = obs.get_gp_cache(data[:-1], **kernel_distsq_dict)
        return dict(Sigma=Sigma, A=A)

    memory_kwargs = get_memory_kwargs(training_data, train_rs)
    valid_data_memory_kwargs = get_memory_kwargs(valid_data, train_rs)

    log_prob = model.log_likelihood(
        training_data,
        transition_memory_kwargs=transition_memory_kwargs,
        **memory_kwargs)
    print("log_prob = {}".format(log_prob))

    ##################### training ############################
    if train_model:
        print("start training")
        list_of_losses = []
        # Truthiness also covers load_opt_dir=None (previously joblib.load(None)).
        if load_opt_dir:
            opt = joblib.load(load_opt_dir)
        else:
            opt = None
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters, list_of_lr)):
            training_losses, opt, valid_losses, rs_s = \
                model.fit(training_data, optimizer=opt, method='adam',
                          num_iters=num_iters, lr=lr,
                          pbar_update_interval=pbar_update_interval,
                          valid_data=valid_data,
                          transition_memory_kwargs=transition_memory_kwargs,
                          valid_data_transition_memory_kwargs=valid_data_transition_memory_kwargs,
                          valid_data_memory_kwargs=valid_data_memory_kwargs,
                          **memory_kwargs)
            list_of_losses.append(training_losses)

            checkpoint_dir = rslt_dir + "/checkpoint_{}".format(i)

            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))

            # save model and opt
            joblib.dump(model, checkpoint_dir + "/model")
            joblib.dump(opt, checkpoint_dir + "/optimizer")

            # save losses
            losses = dict(training_loss=training_losses, valid_loss=valid_losses)
            joblib.dump(losses, checkpoint_dir + "/losses")

            plt.figure()
            plt.plot(training_losses)
            plt.title("training loss")
            plt.savefig(checkpoint_dir + "/training_losses.jpg")
            plt.close()

            plt.figure()
            plt.plot(valid_losses)
            plt.title("validation loss")
            plt.savefig(checkpoint_dir + "/valid_losses.jpg")
            plt.close()

            # rs_s is indexed [record, state, axis]; one curve per state.
            # `state` (not `k`) so the `k` argument is not shadowed.
            rs_s = np.array(rs_s)
            plt.figure(figsize=(10, 5))
            plt.subplot(1, 2, 1)
            plt.title("lengthscale(x)")
            for state in range(K):
                plt.plot(rs_s[:, state, 0], label='k={}'.format(state))
            plt.legend()
            plt.subplot(1, 2, 2)
            plt.title("lengthscale(y)")
            for state in range(K):
                plt.plot(rs_s[:, state, 1], label='k={}'.format(state))
            plt.legend()
            plt.savefig(checkpoint_dir + "/observation_rs.jpg")
            plt.close()

            # save rest
            if i in ckpts_not_to_save:
                print("ckpt {}: skip!\n".format(i))
                continue
            with torch.no_grad():
                rslt_saving(rslt_dir=checkpoint_dir, model=model,
                            data=training_data, animal=animal,
                            memory_kwargs=memory_kwargs,
                            list_of_k_steps=list_of_k_steps, sample_T=sample_T,
                            quiver_scale=quiver_scale, valid_data=valid_data,
                            transition_memory_kwargs=transition_memory_kwargs,
                            valid_data_transition_memory_kwargs=
                            valid_data_transition_memory_kwargs,
                            valid_data_memory_kwargs=valid_data_memory_kwargs,
                            device=device)
    else:
        # only save the results; no gradients are needed, matching the
        # checkpoint path above
        with torch.no_grad():
            rslt_saving(rslt_dir=rslt_dir, model=model, data=training_data,
                        animal=animal, memory_kwargs=memory_kwargs,
                        list_of_k_steps=list_of_k_steps, sample_T=sample_T,
                        quiver_scale=quiver_scale, valid_data=valid_data,
                        transition_memory_kwargs=transition_memory_kwargs,
                        valid_data_transition_memory_kwargs=
                        valid_data_transition_memory_kwargs,
                        valid_data_memory_kwargs=valid_data_memory_kwargs,
                        device=device)

    print("Finish running!")
# Precompute per-animal "memory" from every step except the last one:
# grid masks, momentum vectors and feature vectors. Feature vectors take
# (self, other) positions, so animal b swaps the column blocks.
masks_a, masks_b = tran.get_masks(data[:-1])
momentum_vecs_a = get_momentum_in_batch(data[:-1, 0:2], lags=momentum_lags, weights=momentum_weights)
momentum_vecs_b = get_momentum_in_batch(data[:-1, 2:4], lags=momentum_lags, weights=momentum_weights)
feature_vecs_a = feature_vec_func(data[:-1, 0:2], data[:-1, 2:4])
feature_vecs_b = feature_vec_func(data[:-1, 2:4], data[:-1, 0:2])
m_kwargs_a = dict(momentum_vecs=momentum_vecs_a, feature_vecs=feature_vecs_a)
m_kwargs_b = dict(momentum_vecs=momentum_vecs_b, feature_vecs=feature_vecs_b)
# observation
obs = ARTruncatedNormalObservation(K=K, D=D, M=M, lags=momentum_lags, bounds=bounds, transformation=tran)
# model
model = HMM(K=K, D=D, M=M, observation=obs)
# Consistency check: the log likelihood with precomputed memory must equal the
# one computed from scratch (exact equality).
# NOTE(review): `assert torch.eq(...)` assumes log_likelihood returns a
# single-element tensor — confirm.
log_prob = model.log_likelihood(data, masks=(masks_a, masks_b), memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)
log_prob_2 = model.log_likelihood(data)
assert torch.eq(log_prob, log_prob_2)
print(log_prob)
# training
print("start training...")
num_iters = 10
losses, opt = model.fit(data, num_iters=num_iters, lr=0.001, masks=(masks_a, masks_b), memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)
# sampling
print("start sampling")
def main(job_name, cuda_num, downsample_n, filter_traj, gp_version, load_model,
         load_model_dir, load_opt_dir, transition, sticky_alpha, sticky_kappa,
         acc_factor, k, x_grids, y_grids, n_x, n_y, rs_factor, rs, train_rs,
         train_model, pbar_update_interval, video_clips, held_out_proportion,
         torch_seed, np_seed, list_of_num_iters, ckpts_not_to_save, list_of_lr,
         list_of_k_steps, sample_t, quiver_scale):
    """Fit (or load) an HMM with a ``GPGridTransformation`` observation and save results.

    Steps:
      1. Parse the comma-separated CLI strings.
      2. Load the ``trajs_all`` data and slice the requested video clips.
      3. Downsample (optionally speed-filter) and split chronologically into
         training and validation parts.
      4. Build or load the model and write ``exp_params.json``.
      5. Precompute per-animal memory tensors.
      6. Either train over the ``(num_iters, lr)`` schedule, writing
         ``checkpoint_<i>`` directories, or only save results.

    Args:
        rs_factor: ``"a,b"``; ``"0,0"`` means no rs factor (None).
        rs: parsed with ``float`` and recorded in ``exp_params.json``;
            None is kept as None.
        list_of_num_iters, list_of_lr: comma-separated schedule of equal
            length; every learning rate must be <= 1.
        list_of_k_steps, ckpts_not_to_save: optional comma-separated ints.
        held_out_proportion: trailing fraction kept for validation, in
            ``[0, 0.4]``.
        load_opt_dir: optimizer file to resume from; empty or None means a
            fresh optimizer.

    Raises:
        ValueError: missing ``job_name`` or a learning rate > 1.
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")

    cuda_num = int(cuda_num)
    device = torch.device(
        "cuda:{}".format(cuda_num) if torch.cuda.is_available() else "cpu")
    print("Using device {} \n\n".format(device))

    K = k
    sample_T = sample_t

    rs_factor = np.array([float(x) for x in rs_factor.split(",")])
    # "0,0" means "no rs factor".
    if rs_factor[0] == 0 and rs_factor[1] == 0:
        rs_factor = None
    # None-safe, unlike a bare float(rs).
    rs = float(rs) if rs is not None else None

    video_clip_start, video_clip_end = [
        float(x) for x in video_clips.split(",")
    ]
    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    # Same optional-list handling as the other CLI lists.
    list_of_k_steps = [int(x) for x in list_of_k_steps.split(",")
                       ] if list_of_k_steps else []

    assert len(list_of_num_iters) == len(
        list_of_lr
    ), "Length of list_of_num_iters must match length of list_of_lr."
    for lr in list_of_lr:
        if lr > 1:
            raise ValueError("Learning rate should not be larger than 1!")

    ckpts_not_to_save = [int(x) for x in ckpts_not_to_save.split(',')
                         ] if ckpts_not_to_save else []

    # Data and result paths are resolved relative to the git repository root.
    repo = git.Repo('.', search_parent_directories=True)
    repo_dir = repo.working_tree_dir  # SocialBehavior

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    ########################## data ########################
    data_dir = repo_dir + '/SocialBehaviorptc/data/trajs_all'
    trajs = joblib.load(data_dir)
    # NOTE(review): 36000 looks like the number of frames per video clip — confirm.
    traj = trajs[int(36000 * video_clip_start):int(36000 * video_clip_end)]
    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)

    data = torch.tensor(traj, dtype=torch.float64, device=device)
    assert 0 <= held_out_proportion <= 0.4, \
        "held_out-portion should be between 0 and 0.4 (inclusive), but is {}".format(held_out_proportion)
    T = data.shape[0]
    # Chronological split; `split_idx` avoids shadowing the builtin `breakpoint`.
    split_idx = int(T * (1 - held_out_proportion))
    training_data = data[:split_idx]
    valid_data = data[split_idx:]

    ######################### model ####################
    D = 4
    M = 0
    Df = 4

    if load_model:
        print("Loading the model from ", load_model_dir)
        model = joblib.load(load_model_dir)
        tran = model.observation.transformation

        K = model.K

        n_x = len(tran.x_grids) - 1
        n_y = len(tran.y_grids) - 1

        acc_factor = tran.acc_factor
    else:
        print("Creating the model...")
        bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX],
                           [ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])

        # grids: evenly spaced edges unless explicit comma-separated edges are given
        if x_grids is None:
            x_grid_gap = (ARENA_XMAX - ARENA_XMIN) / n_x
            x_grids = np.array(
                [ARENA_XMIN + i * x_grid_gap for i in range(n_x + 1)])
        else:
            x_grids = np.array([float(x) for x in x_grids.split(",")])
            n_x = len(x_grids) - 1

        if y_grids is None:
            y_grid_gap = (ARENA_YMAX - ARENA_YMIN) / n_y
            y_grids = np.array(
                [ARENA_YMIN + i * y_grid_gap for i in range(n_y + 1)])
        else:
            y_grids = np.array([float(x) for x in y_grids.split(",")])
            n_y = len(y_grids) - 1

        if acc_factor is None:
            acc_factor = downsample_n * 10

        # NOTE(review): the parsed `rs` is only recorded in exp_params; the
        # transformation is always built with rs=None — confirm this is intended.
        tran = GPGridTransformation(K=K, D=D, x_grids=x_grids, y_grids=y_grids,
                                    Df=Df, feature_vec_func=f_corner_vec_func,
                                    acc_factor=acc_factor, rs_factor=rs_factor,
                                    rs=None, train_rs=train_rs, device=device,
                                    version=gp_version)
        obs = ARTruncatedNormalObservation(K=K, D=D, M=M, lags=1, bounds=bounds,
                                           transformation=tran, device=device)

        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        else:
            transition_kwargs = None
        model = HMM(K=K, D=D, M=M, transition=transition, observation=obs,
                    transition_kwargs=transition_kwargs, device=device)
        # Every state starts at the first training observation.
        model.observation.mus_init = training_data[0] * torch.ones(
            K, D, dtype=torch.float64, device=device)

    # save experiment params
    exp_params = {
        "job_name": job_name,
        'downsample_n': downsample_n,
        "filter_traj": filter_traj,
        "load_model": load_model,
        "gp_version": gp_version,
        "load_model_dir": load_model_dir,
        "load_opt_dir": load_opt_dir,
        "transition": transition,
        "sticky_alpha": sticky_alpha,
        "sticky_kappa": sticky_kappa,
        "acc_factor": acc_factor,
        "K": K,
        "x_grids": x_grids,
        "y_grids": y_grids,
        "n_x": n_x,
        "n_y": n_y,
        "rs_factor": get_np(tran.rs_factor),
        "rs": rs,
        "train_rs": train_rs,
        "train_model": train_model,
        "pbar_update_interval": pbar_update_interval,
        "video_clip_start": video_clip_start,
        "video_clip_end": video_clip_end,
        "held_out_proportion": held_out_proportion,
        "torch_seed": torch_seed,
        "np_seed": np_seed,
        "list_of_num_iters": list_of_num_iters,
        "list_of_lr": list_of_lr,
        "list_of_k_steps": list_of_k_steps,
        "ckpts_not_to_save": ckpts_not_to_save,
        "sample_T": sample_T,
        "quiver_scale": quiver_scale
    }

    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/gpgrid/" + job_name)
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving to rlst_dir: ", rslt_dir)
    with open(rslt_dir + "/exp_params.json", "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    # compute memory
    print("Computing memory...")

    def get_memory_kwargs(data, train_rs):
        """Precompute per-animal memory for `data`; empty dict for empty data.

        Always includes corner feature vectors and grid-point / grid-cell
        indices. If `train_rs` is set, squared offsets to the nearby grid
        points are added; otherwise the GP coefficients are added.
        """
        if data is None or data.shape[0] == 0:
            return {}

        feature_vecs_a = f_corner_vec_func(data[:-1, 0:2])
        feature_vecs_b = f_corner_vec_func(data[:-1, 2:4])

        gpt_idx_a, grid_idx_a = tran.get_gpt_idx_and_grid_idx_for_batch(
            data[:-1, 0:2])
        gpt_idx_b, grid_idx_b = tran.get_gpt_idx_and_grid_idx_for_batch(
            data[:-1, 2:4])

        memory = dict(feature_vecs_a=feature_vecs_a,
                      feature_vecs_b=feature_vecs_b,
                      gpt_idx_a=gpt_idx_a,
                      gpt_idx_b=gpt_idx_b,
                      grid_idx_a=grid_idx_a,
                      grid_idx_b=grid_idx_b)

        if train_rs:
            nearby_gpts_a = tran.gridpoints[gpt_idx_a]
            dist_sq_a = (data[:-1, None, 0:2] - nearby_gpts_a)**2

            nearby_gpts_b = tran.gridpoints[gpt_idx_b]
            dist_sq_b = (data[:-1, None, 2:4] - nearby_gpts_b)**2

            memory.update(dist_sq_a=dist_sq_a, dist_sq_b=dist_sq_b)
        else:
            # NOTE(review): both animals pass 0 as the second argument, whereas
            # the GP-observation script uses 0 for a and 1 for b — confirm.
            coeff_a = tran.get_gp_coefficients(data[:-1, 0:2], 0, gpt_idx_a,
                                               grid_idx_a)
            coeff_b = tran.get_gp_coefficients(data[:-1, 2:4], 0, gpt_idx_b,
                                               grid_idx_b)
            memory.update(coeff_a=coeff_a, coeff_b=coeff_b)
        return memory

    memory_kwargs = get_memory_kwargs(training_data, train_rs)
    valid_data_memory_kwargs = get_memory_kwargs(valid_data, train_rs)

    # Initial log likelihood, reported so it can be compared with later losses.
    log_prob = model.log_likelihood(training_data, **memory_kwargs)
    print("log_prob = {}".format(log_prob))

    ##################### training ############################
    if train_model:
        print("start training")
        list_of_losses = []
        # Truthiness also covers load_opt_dir=None (previously joblib.load(None)).
        if load_opt_dir:
            opt = joblib.load(load_opt_dir)
        else:
            opt = None
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters, list_of_lr)):
            training_losses, opt, valid_losses = model.fit(
                training_data, optimizer=opt, method='adam',
                num_iters=num_iters, lr=lr,
                pbar_update_interval=pbar_update_interval,
                valid_data=valid_data,
                valid_data_memory_kwargs=valid_data_memory_kwargs,
                **memory_kwargs)
            list_of_losses.append(training_losses)

            checkpoint_dir = rslt_dir + "/checkpoint_{}".format(i)

            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))

            # save model and opt
            joblib.dump(model, checkpoint_dir + "/model")
            joblib.dump(opt, checkpoint_dir + "/optimizer")

            # save losses
            losses = dict(training_loss=training_losses, valid_loss=valid_losses)
            joblib.dump(losses, checkpoint_dir + "/losses")

            plt.figure()
            plt.plot(training_losses)
            plt.title("training loss")
            plt.savefig(checkpoint_dir + "/training_losses.jpg")
            plt.close()

            plt.figure()
            plt.plot(valid_losses)
            plt.title("validation loss")
            plt.savefig(checkpoint_dir + "/valid_losses.jpg")
            plt.close()

            # save rest
            if i in ckpts_not_to_save:
                print("ckpt {}: skip!\n".format(i))
                continue
            with torch.no_grad():
                rslt_saving(rslt_dir=checkpoint_dir, model=model,
                            data=training_data, memory_kwargs=memory_kwargs,
                            list_of_k_steps=list_of_k_steps, sample_T=sample_T,
                            quiver_scale=quiver_scale, valid_data=valid_data,
                            valid_data_memory_kwargs=valid_data_memory_kwargs,
                            device=device)
    else:
        # only save the results; no gradients needed, matching the checkpoint path
        with torch.no_grad():
            rslt_saving(rslt_dir=rslt_dir, model=model, data=training_data,
                        memory_kwargs=memory_kwargs,
                        list_of_k_steps=list_of_k_steps, sample_T=sample_T,
                        quiver_scale=quiver_scale, valid_data=valid_data,
                        valid_data_memory_kwargs=valid_data_memory_kwargs,
                        device=device)

    print("Finish running!")
def main(job_name, downsample_n, filter_traj, load_model, load_model_dir,
         load_opt_dir, transition, sticky_alpha, sticky_kappa, acc_factor, k,
         x_grids, y_grids, n_x, n_y, train_model, pbar_update_interval,
         video_clips, torch_seed, np_seed, list_of_num_iters, list_of_lr,
         sample_t, quiver_scale):
    """Fit (or load) an HMM with a ``GridTransformation`` observation and save results.

    Steps:
      1. Load the ``trajs_all`` data and slice the integer video clips
         ``"start,end"``.
      2. Downsample (optionally speed-filter).
      3. Build or load the model and write ``exp_params.json``.
      4. Precompute grid masks and corner feature vectors.
      5. Either train over the ``(num_iters, lr)`` schedule, writing
         ``checkpoint_<i>`` directories, or only save results.

    Args:
        list_of_num_iters, list_of_lr: comma-separated schedule of equal
            length; every learning rate must be <= 1.
        load_opt_dir: optimizer file to resume from; empty or None means a
            fresh optimizer.
        acc_factor: defaults to ``downsample_n * 10`` for a new model.

    Raises:
        ValueError: missing ``job_name`` or a learning rate > 1.
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")

    K = k
    sample_T = sample_t
    video_clip_start, video_clip_end = [int(x) for x in video_clips.split(",")]
    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    assert len(list_of_num_iters) == len(list_of_lr), \
        "Length of list_of_num_iters must match length of list_of_lr."
    for lr in list_of_lr:
        if lr > 1:
            raise ValueError("Learning rate should not be larger than 1!")

    # Data and result paths are resolved relative to the git repository root.
    repo = git.Repo('.', search_parent_directories=True)
    repo_dir = repo.working_tree_dir  # SocialBehavior

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    ########################## data ########################
    data_dir = repo_dir + '/SocialBehaviorptc/data/trajs_all'
    trajs = joblib.load(data_dir)
    # NOTE(review): 36000 looks like the number of frames per video clip — confirm.
    traj = trajs[36000 * video_clip_start:36000 * video_clip_end]
    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)
    data = torch.tensor(traj, dtype=torch.float64)

    ######################### model ####################
    D = 4
    M = 0
    Df = 4

    if load_model:
        print("Loading the model from ", load_model_dir)
        model = joblib.load(load_model_dir)
        tran = model.observation.transformation

        K = model.K

        n_x = len(tran.x_grids) - 1
        n_y = len(tran.y_grids) - 1

        acc_factor = tran.transformations_a[0].acc_factor
    else:
        print("Creating the model...")
        bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX],
                           [ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])

        # grids: evenly spaced edges unless explicit comma-separated edges are given
        if x_grids is None:
            x_grid_gap = (ARENA_XMAX - ARENA_XMIN) / n_x
            x_grids = [ARENA_XMIN + i * x_grid_gap for i in range(n_x + 1)]
        else:
            x_grids = [float(x) for x in x_grids.split(",")]
            n_x = len(x_grids) - 1

        if y_grids is None:
            y_grid_gap = (ARENA_YMAX - ARENA_YMIN) / n_y
            y_grids = [ARENA_YMIN + i * y_grid_gap for i in range(n_y + 1)]
        else:
            y_grids = [float(x) for x in y_grids.split(",")]
            n_y = len(y_grids) - 1

        if acc_factor is None:
            acc_factor = downsample_n * 10

        tran = GridTransformation(K=K, D=D, x_grids=x_grids, y_grids=y_grids,
                                  unit_transformation="direction", Df=Df,
                                  feature_vec_func=f_corner_vec_func,
                                  acc_factor=acc_factor)
        obs = ARTruncatedNormalObservation(K=K, D=D, M=M, lags=1, bounds=bounds,
                                           transformation=tran)

        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        else:
            transition_kwargs = None
        model = HMM(K=K, D=D, M=M, transition=transition, observation=obs,
                    transition_kwargs=transition_kwargs)
        # Every state starts at the first observation.
        model.observation.mus_init = data[0] * torch.ones(K, D, dtype=torch.float64)

    # save experiment params (seeds / filter / quiver scale included for
    # reproducibility)
    exp_params = {"job_name": job_name,
                  'downsample_n': downsample_n,
                  "filter_traj": filter_traj,
                  "load_model": load_model,
                  "load_model_dir": load_model_dir,
                  "load_opt_dir": load_opt_dir,
                  "transition": transition,
                  "sticky_alpha": sticky_alpha,
                  "sticky_kappa": sticky_kappa,
                  "acc_factor": acc_factor,
                  "K": K,
                  "n_x": n_x,
                  "n_y": n_y,
                  "x_grids": x_grids,
                  "y_grids": y_grids,
                  "train_model": train_model,
                  "pbar_update_interval": pbar_update_interval,
                  "list_of_num_iters": list_of_num_iters,
                  "list_of_lr": list_of_lr,
                  "video_clip_start": video_clip_start,
                  "video_clip_end": video_clip_end,
                  "torch_seed": torch_seed,
                  "np_seed": np_seed,
                  "sample_T": sample_T,
                  "quiver_scale": quiver_scale}

    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/" + job_name)
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving to rlst_dir: ", rslt_dir)
    with open(rslt_dir + "/exp_params.json", "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    # compute memories (every step except the last, once per animal)
    masks_a, masks_b = tran.get_masks(data[:-1])

    feature_vecs_a = f_corner_vec_func(data[:-1, 0:2])
    feature_vecs_b = f_corner_vec_func(data[:-1, 2:4])

    m_kwargs_a = dict(feature_vecs=feature_vecs_a)
    m_kwargs_b = dict(feature_vecs=feature_vecs_b)

    ##################### training ############################
    if train_model:
        print("start training")
        list_of_losses = []
        # Truthiness also covers load_opt_dir=None (previously joblib.load(None)).
        if load_opt_dir:
            opt = joblib.load(load_opt_dir)
        else:
            opt = None
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters, list_of_lr)):
            losses, opt = model.fit(data, optimizer=opt, method='adam',
                                    num_iters=num_iters, lr=lr,
                                    masks=(masks_a, masks_b),
                                    pbar_update_interval=pbar_update_interval,
                                    memory_kwargs_a=m_kwargs_a,
                                    memory_kwargs_b=m_kwargs_b)
            list_of_losses.append(losses)

            checkpoint_dir = rslt_dir + "/checkpoint_{}".format(i)

            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))
            # save model and opt
            joblib.dump(model, checkpoint_dir + "/model")
            joblib.dump(opt, checkpoint_dir + "/optimizer")
            # save rest
            rslt_saving(checkpoint_dir, model, Df, data, masks_a, masks_b,
                        m_kwargs_a, m_kwargs_b, sample_T, train_model, losses,
                        quiver_scale)
    else:
        # only save the results
        rslt_saving(rslt_dir, model, Df, data, masks_a, masks_b, m_kwargs_a,
                    m_kwargs_b, sample_T, False, [], quiver_scale)

    print("Finish running!")
npr.seed(0)
bounds = np.array([[0, 20], [0, 40]])
# K initial means spaced evenly on a circle of radius 3 centred at (5, 5).
thetas = np.linspace(0, 2 * np.pi, K, endpoint=False)
mus_init = 3 * np.column_stack((np.cos(thetas), np.sin(thetas))) + 5
# Ground-truth AR truncated-normal HMM with a fixed (non-trained) sigma.
true_tran = LinearTransformation(K=K, D=D, lags=lags)
true_observation = ARTruncatedNormalObservation(K=K, D=D, M=0, lags=lags, transformation=true_tran, bounds=bounds, mus_init=mus_init, train_sigma=False)
true_model = HMM(K=K, D=D, M=0, observation=true_observation)
# Sample as tensors (return_np=False) so the log likelihood can be evaluated directly.
z, x = true_model.sample(T, return_np=False)
true_ll = true_model.log_likelihood(x)
print(true_ll)
print("\n # model parameters: \n", len(true_model.params))
print("\n # model trainable parameters: \n", len(true_model.trainable_params))
sample_z, sample_x = true_model.sample(T)
# The triple-quoted string opened below disables the "learning" section; its
# closing quotes are outside this excerpt.
""" # learning tran = LinearTransformation(K=K, D=D, momentum_lags=1, As=As)
from hips.plotting.colormaps import gradient_cmap, white_to_color_cmap
# Plotting palette: xkcd colour names turned into a seaborn palette and a
# gradient colormap (used by plotting code, presumably later in the script).
color_names = [
    "windows blue", "red", "amber", "faded green", "dusty purple", "orange"
]
colors = sns.xkcd_palette(color_names)
cmap = gradient_cmap(colors)

# Seed both numpy and torch so the input noise and the sampled trajectory are
# reproducible.
npr.seed(0)
torch.manual_seed(0)

K = 3  # number of discrete states
D = 2  # observation dimension
M = 1  # input dimension
T = 100  # number of time steps

# Create an exogenous input: a sinusoid with period 50 steps plus Gaussian
# noise with std 0.1; [:, None] broadcasts it against the (T, M) noise.
inpt = np.sin(2 * np.pi * np.arange(T) / 50)[:, None] + 1e-1 * npr.randn(T, M)
inpt = torch.tensor(inpt, dtype=torch.float64)

# Ground-truth HMM whose transitions are driven by the exogenous input.
true_model = HMM(K=K, D=D, M=M, transition='inputdriven', observation='gaussian')
z, data = true_model.sample(T, input=inpt, return_np=True)
# Log-likelihood of the sampled data under the generating model, given the same input.
lls = true_model.log_likelihood(data, inpt)
K = 3  # number of discrete states
D = 2  # observation dimension
T = 100  # number of time steps

# Ground-truth dynamics: one matrix per state from random_rotation
# (presumably a random D x D rotation -- confirm against the helper).
As = [random_rotation(D) for _ in range(K)]
true_tran = LinearTransformation(K=K, d_in=D, D=D, As=As)

# One [low, high] row per observation dimension (presumably the support of the
# logit-normal observation).
bounds = np.array([[0, 20], [-5, 25]])
true_observation = ARLogitNormalObservation(K=K, D=D, M=0, transformation=true_tran, bounds=bounds)

true_model = HMM(K=K, D=D, M=0, observation=true_observation)

# Sample states and observations, kept as torch tensors (return_np=False).
z, data = true_model.sample(T, return_np=False)

# Define a model to fit the data
# (same structure as the ground truth, but the transformation is left at its
# default initialisation instead of using As).
tran = LinearTransformation(K=K, d_in=D, D=D)
observation = ARLogitNormalObservation(K=K, D=D, M=0, transformation=tran, bounds=bounds)
model = HMM(K=K, D=D, M=0, observation=observation)

# Model fitting
[4.0, 7.0, 8.0, 5.0], [6.0, 7.0, 5.0, 6.0], [8.0, 2.0, 6.0, 1.0]]) data = torch.tensor(data, dtype=torch.float64) bounds = np.array([[0.0, 10.0], [0.0, 8.0], [0.0, 10.0], [0.0, 8.0]]) # model K = 3 D = 4 M = 0 obs = TruncatedNormalObservation(K=K, D=D, M=M, bounds=bounds) # model model = HMM(K=K, D=D, M=M, observation=obs) # log like log_prob = model.log_probability(data) print("log probability = ", log_prob) # training print("start training...") num_iters = 10 losses, opt = model.fit(data, num_iters=num_iters, lr=0.001) # sampling samplt_T = T print("start sampling") sample_z, sample_x = model.sample(samplt_T)
from ssm_ptc.models.hmm import HMM
from ssm_ptc.observations.ar_gaussian_observation import ARGaussianObservation
from ssm_ptc.transformations.linear import LinearTransformation

import joblib

# Check that an HMM survives a joblib dump/load round-trip with identical parameters.

torch.manual_seed(0)
np.random.seed(0)

K = 3  # number of discrete states
D = 2  # observation dimension
lags = 10  # autoregressive lags of the linear transformation

# AR Gaussian
trans1 = LinearTransformation(K=K, D=D, lags=lags)
obs1 = ARGaussianObservation(K=K, D=D, transformation=trans1, train_sigma=False)
model = HMM(K=K, D=D, observation=obs1)

# NOTE(review): this writes "test_save_model" into the current working
# directory and never removes it; consider a tempfile if that matters.
filename = "test_save_model"
joblib.dump(model, filename)

model_recovered = joblib.load(filename)

# Every parameter of the recovered model must equal the original element-wise.
# NOTE(review): zip stops at the shorter sequence, so a difference in the
# *number* of parameters would not be detected here.
for p1, p2 in zip(model.params_unpack, model_recovered.params_unpack):
    assert torch.all(torch.eq(p1, p2))
def main(job_name, downsample_n, mouse, filter_traj, load_model, load_model_dir,
         load_opt_dir, transition, sticky_alpha, sticky_kappa, k, x_grids, y_grids,
         n_x, n_y, not_train_mu, initialize_mu, train_model, pbar_update_interval,
         video_clips, torch_seed, np_seed, list_of_num_iters, list_of_lr, sample_t,
         quiver_scale):
    """Fit (or load) an HMM with a DynamicLocationObservation and save results.

    Comma-separated string arguments (``video_clips``, ``list_of_num_iters``,
    ``list_of_lr`` and, when not None, ``x_grids`` / ``y_grids``) are parsed
    here. Experiment parameters, checkpoints and result plots are written under
    a directory derived from ``rslts/dynamic_loc/<job_name>`` (via
    ``addDateTime``) inside the enclosing git repository.

    :raises ValueError: if ``job_name`` is None, ``mouse`` is not one of
        'both', 'virgin', 'mother', or any learning rate is larger than 1.
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")
    # Validate early so a bad value fails before the (slow) data load.
    if mouse not in ("both", "virgin", "mother"):
        raise ValueError(
            "mouse must be chosen from 'both', 'virgin', 'mother'.")

    K = k
    sample_T = sample_t
    video_clip_start, video_clip_end = [int(x) for x in video_clips.split(",")]
    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    assert len(list_of_num_iters) == len(list_of_lr), \
        "Length of list_of_num_iters must match length of list_of_lr."
    for lr in list_of_lr:
        if lr > 1:
            raise ValueError("Learning rate should not be larger than 1!")

    repo = git.Repo('.', search_parent_directories=True)  # SocialBehaviorectories=True)
    repo_dir = repo.working_tree_dir  # SocialBehavior

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    # specify grids for sanity checks and plotting:
    # either n evenly spaced cells over the arena, or explicit edges.
    if x_grids is None:
        x_grid_gap = (ARENA_XMAX - ARENA_XMIN) / n_x
        x_grids = np.array([ARENA_XMIN + i * x_grid_gap for i in range(n_x + 1)])
    else:
        x_grids = np.array([float(x) for x in x_grids.split(",")])
        n_x = len(x_grids) - 1

    if y_grids is None:
        y_grid_gap = (ARENA_YMAX - ARENA_YMIN) / n_y
        y_grids = np.array([ARENA_YMIN + i * y_grid_gap for i in range(n_y + 1)])
    else:
        y_grids = np.array([float(x) for x in y_grids.split(",")])
        n_y = len(y_grids) - 1

    ########################## data ########################
    data_dir = repo_dir + '/SocialBehaviorptc/data/trajs_all'
    trajs = joblib.load(data_dir)
    # 36000 frames per video clip.
    traj = trajs[36000 * video_clip_start:36000 * video_clip_end]
    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)

    # Columns 0:2 belong to the virgin, 2:4 to the mother.
    if mouse == "both":
        data = torch.tensor(traj, dtype=torch.float64)
        D = 4
    elif mouse == "virgin":
        data = torch.tensor(traj[:, 0:2], dtype=torch.float64)
        D = 2
    elif mouse == "mother":
        data = torch.tensor(traj[:, 2:4], dtype=torch.float64)
        D = 2
    else:
        raise ValueError(
            "mouse must be chosen from 'both', 'virgin', 'mother'.")

    ######################### model ####################
    M = 0
    if load_model:
        print("Loading the model from ", load_model_dir)
        model = joblib.load(load_model_dir)
        K = model.K
        # The cluster centres live on the transformation (see the creation
        # path below), which is also what checkpoints saved by this script hold.
        cluster_locations = get_np(model.observation.transformation.mus_loc)
    else:
        print("Creating the model...")
        if D == 4:
            bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX],
                               [ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])
        else:
            bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])

        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        else:
            transition_kwargs = None
        obs = DynamicLocationObservation(K=K, D=D, M=M, bounds=bounds)
        model = HMM(K=K, D=D, M=M, transition=transition, observation=obs,
                    transition_kwargs=transition_kwargs)

        if initialize_mu:
            # One cluster centre per grid cell (per pair of cells for both mice).
            if mouse == 'both':
                assert K == (n_x * n_y)**2, "K should be equal to (n_x*n_y)**2"
            else:
                assert K == n_x * n_y, "K should be equal to n_x * n_y"
            cluster_locations = np.array([
                ((x_grids[i] + x_grids[i + 1]) / 2, (y_grids[j] + y_grids[j + 1]) / 2)
                for i in range(n_x) for j in range(n_y)
            ])
            if mouse == "both":
                cluster_locations = np.array([
                    np.concatenate((loc_a, loc_b))
                    for loc_a in cluster_locations for loc_b in cluster_locations
                ])
            assert cluster_locations.shape == (K, D)
            train_mu = not not_train_mu
            model.observation.transformation.mus_loc = \
                torch.tensor(cluster_locations, dtype=torch.float64, requires_grad=train_mu)
        else:
            cluster_locations = get_np(model.observation.transformation.mus_loc)

    # save experiment params
    exp_params = {
        "job_name": job_name,
        'downsample_n': downsample_n,
        "mouse": mouse,
        "load_model": load_model,
        "load_model_dir": load_model_dir,
        "load_opt_dir": load_opt_dir,
        "transition": transition,
        "sticky_alpha": sticky_alpha,
        "sticky_kappa": sticky_kappa,
        "K": K,
        "n_x": n_x,
        "n_y": n_y,
        "x_grids": x_grids,
        "y_grids": y_grids,
        "cluster_locations": cluster_locations,
        "not_train_mu": not_train_mu,
        "initialize_mu": initialize_mu,
        "train_model": train_model,
        "pbar_update_interval": pbar_update_interval,
        "list_of_num_iters": list_of_num_iters,
        "list_of_lr": list_of_lr,
        "video_clip_start": video_clip_start,
        "video_clip_end": video_clip_end,
        "sample_T": sample_T
    }

    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/dynamic_loc/" + job_name)
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving to rlst_dir: ", rslt_dir)
    with open(rslt_dir + "/exp_params.json", "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    ##################### training ############################
    if train_model:
        # Save results of the untrained model for comparison.
        initial_dir = rslt_dir + "/initial"
        if not os.path.exists(initial_dir):
            os.makedirs(initial_dir)
            print("\nCreating initial directory...")
        rslt_saving(initial_dir, model, data, mouse, sample_T, False, [], quiver_scale,
                    x_grids, y_grids)

        print("start training")
        # Treat both "" and None as "no saved optimizer".
        if load_opt_dir:
            opt = joblib.load(load_opt_dir)
        else:
            opt = None
        # One training stage (and one checkpoint directory) per (num_iters, lr) pair;
        # the optimizer is carried over between stages.
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters, list_of_lr)):
            losses, opt = model.fit(data, optimizer=opt, method='adam', num_iters=num_iters,
                                    lr=lr, pbar_update_interval=pbar_update_interval)

            checkpoint_dir = rslt_dir + "/checkpoint_{}".format(i)
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))
            # save model and opt
            joblib.dump(model, checkpoint_dir + "/model")
            joblib.dump(opt, checkpoint_dir + "/optimizer")
            # save rest
            rslt_saving(checkpoint_dir, model, data, mouse, sample_T, train_model, losses,
                        quiver_scale, x_grids, y_grids)
    else:
        # only save the results
        rslt_saving(rslt_dir, model, data, mouse, sample_T, False, [], quiver_scale,
                    x_grids, y_grids)

    print("Finish running!")
x_grids=x_grids, y_grids=y_grids, unit_transformation="direction", Df=Df, feature_vec_func=f_corner_vec_func, acc_factor=10) obs = ARTruncatedNormalObservation(K=K, D=D, M=M, lags=1, bounds=bounds, transformation=tran) # model model = HMM(K=K, D=D, M=M, observation=obs) #model.observation.mus_init = data[0] * torch.ones(K, D, dtype=torch.float64) params = joblib.load( "/Users/leah/Columbia/courses/19summer/SocialBehavior/SocialBehaviorptc/project_notebooks/gridmodel/model_k2" ) model.params = params obs.log_sigmas = torch.log(torch.ones((2, 4), dtype=torch.float64) * 0.01) center_z = torch.tensor([0], dtype=torch.int) center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64) sample_z_center, sample_x_center = model.sample(20,
def test_model():
    """Smoke-test an HMM with a UniLSTMTransformation on random 4-D data.

    Runs log-likelihood, a short Adam fit, most-likely-state decoding, 0-step
    and 10-step prediction, and sampling. It only checks that these calls run
    and prints intermediate results; nothing is asserted.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    T = 100
    D = 4

    data = np.random.randn(T, D)
    data = torch.tensor(data, dtype=torch.float64)

    # Bounds: the joint x-range (columns 0 and 2) and y-range (columns 1 and 3)
    # of the data, padded by 1 on each side.
    data_np = data.numpy()
    xs = data_np[:, [0, 2]]
    ys = data_np[:, [1, 3]]
    x_bound = [xs.min() - 1, xs.max() + 1]
    y_bound = [ys.min() - 1, ys.max() + 1]
    bounds = np.array([x_bound, y_bound, x_bound, y_bound])

    def toy_feature_vec_func(s):
        """Features from `feature_direction_vec` for `s` and four fixed corner points.

        :param s: self, (T, 2)
        :return: features, (T, Df, 2)
        """
        corners = torch.tensor([[0, 0], [0, 8], [10, 0], [10, 8]], dtype=torch.float64)
        return feature_direction_vec(s, corners)

    K = 3
    Df = 4
    lags = 1

    tran = UniLSTMTransformation(K=K, D=D, Df=Df, feature_vec_func=toy_feature_vec_func,
                                 lags=lags, dh=10)

    # observation
    obs = ARTruncatedNormalObservation(K=K, D=D, lags=lags, bounds=bounds, transformation=tran)

    # model
    model = HMM(K=K, D=D, observation=obs)

    print("calculating log likelihood")
    # Precompute per-animal features and packed lags once and reuse them below.
    feature_vecs_a = toy_feature_vec_func(data[:-1, 0:2])
    feature_vecs_b = toy_feature_vec_func(data[:-1, 2:4])
    feature_vecs = (feature_vecs_a, feature_vecs_b)
    packed_data = get_packed_data(data[:-1], lags=lags)
    model.log_likelihood(data, feature_vecs=feature_vecs, packed_data=packed_data)

    # fit
    losses, _ = model.fit(data, optimizer=None, method="adam", num_iters=50,
                          feature_vecs=feature_vecs, packed_data=packed_data)

    plt.figure()
    plt.plot(losses)
    plt.show()

    # most-likely-z
    print("Most likely z...")
    z = model.most_likely_states(data, feature_vecs=feature_vecs, packed_data=packed_data)

    # prediction
    print("0 step prediction")
    # NOTE: feature_vecs were computed on the full `data`; the slice below only
    # happens for T > 1000, which this test never reaches.
    if data.shape[0] <= 1000:
        data_to_predict = data
    else:
        data_to_predict = data[-1000:]
    x_predict = k_step_prediction_for_lstm_model(model, z, data_to_predict,
                                                 feature_vecs=feature_vecs)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()), axis=0)
    print("0 step prediction error = {}".format(x_predict_err))

    print("10 step prediction")
    x_predict_2 = k_step_prediction(model, z, data_to_predict, k=10)
    x_predict_2_err = np.mean(np.abs(x_predict_2 - data_to_predict[10:].numpy()), axis=0)
    print("10 step prediction error = {}".format(x_predict_2_err))

    # samples
    print("sampling...")
    sample_T = 5
    sample_z, sample_x = model.sample(sample_T)
def main(job_name, cuda_num, data_type, downsample_n, filter_traj, lg_version,
         use_log_prior, no_boundary_prior, add_log_diagonal_prior, log_prior_sigma_sq,
         load_model, load_model_dir, load_opt_dir, animal, reset_prior_info, transition,
         sticky_alpha, sticky_kappa, acc_factor, k, x_grids, y_grids, n_x, n_y,
         train_model, pbar_update_interval, prop_start_end, video_clips,
         held_out_proportion, torch_seed, np_seed, list_of_num_iters, ckpts_not_to_save,
         list_of_lr, list_of_k_steps, sample_t, quiver_scale):
    """Fit (or load) a linear-grid HMM on mouse trajectories and save the results.

    Parsing of comma-separated string arguments:

    - Always parsed here: ``video_clips``, ``prop_start_end``,
      ``list_of_num_iters`` and ``list_of_lr``.
    - Parsed when non-empty: ``list_of_k_steps`` and ``ckpts_not_to_save``.
    - Parsed when not None: ``x_grids`` and ``y_grids``.

    The last ``held_out_proportion`` of the downsampled data is used as
    validation data. Experiment parameters, checkpoints, losses and plots are
    written under a directory derived from
    ``rslts/lineargrid/<animal>/<job_name>`` (via ``addDateTime``) inside the
    enclosing git repository.

    :raises ValueError: if ``job_name`` is None, ``data_type`` is unsupported,
        or any learning rate is larger than 1.
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")
    assert animal in ['both', 'virgin', 'mother'], animal

    cuda_num = int(cuda_num)
    device = torch.device("cuda:{}".format(cuda_num) if torch.cuda.is_available() else "cpu")
    print("Using device {} \n\n".format(device))

    K = k
    video_clip_start, video_clip_end = [float(x) for x in video_clips.split(",")]
    start, end = [float(x) for x in prop_start_end.split(",")]
    log_prior_sigma_sq = float(log_prior_sigma_sq)
    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    list_of_k_steps = [int(x) for x in list_of_k_steps.split(",")] if list_of_k_steps else []

    assert len(list_of_num_iters) == len(list_of_lr), \
        "Length of list_of_num_iters must match length of list_of_lr."
    for lr in list_of_lr:
        if lr > 1:
            raise ValueError("Learning rate should not be larger than 1!")

    ckpts_not_to_save = [int(x) for x in ckpts_not_to_save.split(',')] if ckpts_not_to_save else []

    repo = git.Repo('.', search_parent_directories=True)  # SocialBehaviorectories=True)
    repo_dir = repo.working_tree_dir  # SocialBehavior

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    ########################## data ########################
    if data_type == 'full':
        data_dir = repo_dir + '/SocialBehaviorptc/data/trajs_all'
        traj = joblib.load(data_dir)
        # 36000 frames per video clip; clip bounds may be fractional.
        traj = traj[int(36000 * video_clip_start):int(36000 * video_clip_end)]
    elif data_type == 'selected_010_virgin':
        assert animal == 'virgin', "animal much be 'virgin', but got {}.".format(animal)
        data_dir = repo_dir + '/SocialBehaviorptc/data/traj_010_virgin_selected'
        traj = joblib.load(data_dir)
        T = len(traj)
        traj = traj[int(T*start): int(T*end)]
    else:
        raise ValueError("unsupported data type: {}".format(data_type))

    # Columns 0:2 belong to the virgin, 2:4 to the mother.
    if animal == 'virgin':
        traj = traj[:, 0:2]
    elif animal == 'mother':
        traj = traj[:, 2:4]

    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)

    data = torch.tensor(traj, dtype=torch.float64, device=device)
    assert 0 <= held_out_proportion <= 0.4, \
        "held_out-portion should be between 0 and 0.4 (inclusive), but is {}".format(held_out_proportion)
    T = data.shape[0]
    # Named split_idx (not `breakpoint`) to avoid shadowing the builtin.
    split_idx = int(T*(1-held_out_proportion))
    training_data = data[:split_idx]
    valid_data = data[split_idx:]
    # NOTE(review): the sample length is the training length; the `sample_t`
    # argument is currently not used -- confirm this is intended.
    sample_T = training_data.shape[0]

    ######################### model ####################
    D = data.shape[1]
    assert D in (2, 4), D
    M = 0
    Df = 4

    if load_model:
        print("Loading the model from ", load_model_dir)
        if reset_prior_info:
            # Rebuild the model around the pretrained observation, overriding
            # its prior settings with the ones passed to this function.
            pretrained_model = joblib.load(load_model_dir)
            pretrained_transition = pretrained_model.transition
            pretrained_observation = pretrained_model.observation
            pretrained_tran = pretrained_model.observation.transformation

            # set prior info
            pretrained_tran.use_log_prior = use_log_prior
            pretrained_tran.no_boundary_prior = no_boundary_prior
            pretrained_tran.add_log_diagonal_prior = add_log_diagonal_prior
            pretrained_tran.log_prior_sigma_sq = torch.tensor(log_prior_sigma_sq, dtype=torch.float64,
                                                              device=device)
            acc_factor = pretrained_tran.acc_factor

            K = pretrained_model.K

            obs = ARTruncatedNormalObservation(K=K, D=D, M=0, obs=pretrained_observation, device=device)
            tran = obs.transformation

            # NOTE(review): unlike the model-creation path below, no
            # transition_kwargs are built for transition == 'grid' here.
            if transition == 'sticky':
                transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
            else:
                transition_kwargs = None
            model = HMM(K=K, D=D, M=M, pi0=get_np(pretrained_model.pi0), Pi=get_np(pretrained_transition.Pi),
                        transition=transition, observation=obs, transition_kwargs=transition_kwargs,
                        device=device)
            model.observation.mus_init = training_data[0] * torch.ones(K, D, dtype=torch.float64, device=device)
        else:
            model = joblib.load(load_model_dir)
            tran = model.observation.transformation

        K = model.K

        n_x = len(tran.x_grids) - 1
        n_y = len(tran.y_grids) - 1

        # Record the loaded transformation's settings. The check must be on
        # the transformation: `model` is an HMM and never matches this class.
        # With reset_prior_info the settings were set from the arguments above.
        if not reset_prior_info and isinstance(tran, LinearGridTransformation):
            acc_factor = tran.acc_factor
            use_log_prior = tran.use_log_prior
            no_boundary_prior = tran.no_boundary_prior
            add_log_diagonal_prior = tran.add_log_diagonal_prior
            log_prior_sigma_sq = get_np(tran.log_prior_sigma_sq)
    else:
        if D == 4:
            bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX],
                               [ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])
        else:
            bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])

        # grids: either n evenly spaced cells over the arena, or explicit edges.
        if x_grids is None:
            x_grid_gap = (ARENA_XMAX - ARENA_XMIN) / n_x
            x_grids = np.array([ARENA_XMIN + i * x_grid_gap for i in range(n_x + 1)])
        else:
            x_grids = np.array([float(x) for x in x_grids.split(",")])
            n_x = len(x_grids) - 1

        if y_grids is None:
            y_grid_gap = (ARENA_YMAX - ARENA_YMIN) / n_y
            y_grids = np.array([ARENA_YMIN + i * y_grid_gap for i in range(n_y + 1)])
        else:
            y_grids = np.array([float(x) for x in y_grids.split(",")])
            n_y = len(y_grids) - 1

        if acc_factor is None:
            acc_factor = downsample_n * 10

        print("Creating the model...")
        if animal == 'both':
            tran = LinearGridTransformation(K=K, D=D, x_grids=x_grids, y_grids=y_grids,
                                            Df=Df, feature_vec_func=f_corner_vec_func,
                                            acc_factor=acc_factor, use_log_prior=use_log_prior,
                                            no_boundary_prior=no_boundary_prior,
                                            add_log_diagonal_prior=add_log_diagonal_prior,
                                            log_prior_sigma_sq=log_prior_sigma_sq, device=device,
                                            version=lg_version)
        else:
            tran = SingleLinearGridTransformation(K=K, D=D, x_grids=x_grids, y_grids=y_grids, device=device)
        obs = ARTruncatedNormalObservation(K=K, D=D, M=M, lags=1, bounds=bounds, transformation=tran,
                                           device=device)

        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        elif transition == 'grid':
            transition_kwargs = dict(x_grids=x_grids, y_grids=y_grids)
        else:
            transition_kwargs = None
        model = HMM(K=K, D=D, M=M, transition=transition, observation=obs,
                    transition_kwargs=transition_kwargs, device=device)
        model.observation.mus_init = training_data[0] * torch.ones(K, D, dtype=torch.float64, device=device)

    # save experiment params
    exp_params = {"job_name": job_name,
                  'downsample_n': downsample_n,
                  'data_type': data_type,
                  "filter_traj": filter_traj,
                  "lg_version": lg_version,
                  "use_log_prior": use_log_prior,
                  "add_log_diagonal_prior": add_log_diagonal_prior,
                  "no_boundary_prior": no_boundary_prior,
                  "log_prior_sigma_sq": log_prior_sigma_sq,
                  "load_model": load_model,
                  "load_model_dir": load_model_dir,
                  "animal": animal,
                  "reset_prior_info": reset_prior_info,
                  "load_opt_dir": load_opt_dir,
                  "transition": transition,
                  "sticky_alpha": sticky_alpha,
                  "sticky_kappa": sticky_kappa,
                  "acc_factor": acc_factor,
                  "K": K,
                  "x_grids": x_grids,
                  "y_grids": y_grids,
                  "n_x": n_x,
                  "n_y": n_y,
                  "train_model": train_model,
                  "pbar_update_interval": pbar_update_interval,
                  "video_clip_start": video_clip_start,
                  "video_clip_end": video_clip_end,
                  "start_percentage": start,
                  "end_percentage": end,
                  "held_out_proportion": held_out_proportion,
                  "torch_seed": torch_seed,
                  "np_seed": np_seed,
                  "list_of_num_iters": list_of_num_iters,
                  "ckpts_not_to_save": ckpts_not_to_save,
                  "list_of_lr": list_of_lr,
                  "list_of_k_steps": list_of_k_steps,
                  "sample_T": sample_T,
                  "quiver_scale": quiver_scale}

    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/lineargrid/{}/{}".format(animal, job_name))
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving to rlst_dir: ", rslt_dir)
    with open(rslt_dir+"/exp_params.json", "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    # compute memory: precompute per-time-step quantities once so that
    # log-likelihood / fit calls can reuse them.
    if transition == "grid":
        print("Computing transition memory...")
        joint_grid_idx = model.transition.get_grid_idx(training_data[:-1])
        transition_memory_kwargs = dict(joint_grid_idx=joint_grid_idx)
        # NOTE(review): with held_out_proportion == 0 this is called on an
        # empty tensor -- confirm get_grid_idx handles that.
        valid_joint_grid_idx = model.transition.get_grid_idx(valid_data[:-1])
        valid_data_transition_memory_kwargs = dict(joint_grid_idx=valid_joint_grid_idx)
    else:
        transition_memory_kwargs = None
        valid_data_transition_memory_kwargs = None

    print("Computing observation memory...")
    if animal == 'both':
        def get_memory_kwargs(batch):
            """Gridpoints and corner feature vectors for each animal in `batch`."""
            batch_a, batch_b = batch[:, 0:2], batch[:, 2:4]
            gridpoints_idx_a = tran.get_gridpoints_idx_for_batch(batch_a)  # (T-1, n_gps, 4)
            gridpoints_idx_b = tran.get_gridpoints_idx_for_batch(batch_b)  # (T-1, n_gps, 4)
            gridpoints_a = tran.get_gridpoints_for_batch(gridpoints_idx_a)  # (T-1, d, 2)
            gridpoints_b = tran.get_gridpoints_for_batch(gridpoints_idx_b)  # (T-1, d, 2)
            feature_vecs_a = f_corner_vec_func(batch_a)  # (T, Df, 2)
            feature_vecs_b = f_corner_vec_func(batch_b)  # (T, Df, 2)
            return dict(gridpoints_idx=(gridpoints_idx_a, gridpoints_idx_b),
                        gridpoints=(gridpoints_a, gridpoints_b),
                        feature_vecs=(feature_vecs_a, feature_vecs_b))

        memory_kwargs = get_memory_kwargs(training_data[:-1])
        if len(valid_data) > 0:
            valid_data_memory_kwargs = get_memory_kwargs(valid_data[:-1])
        else:
            valid_data_memory_kwargs = {}
    else:
        def get_memory_kwargs(batch):
            """Gridpoint indices and coefficients for `batch`; {} when it is empty."""
            if batch is None or batch.shape[0] == 0:
                return {}
            gridpoints_idx = tran.get_gridpoints_idx_for_batch(batch)
            gridpoints = tran.gridpoints[gridpoints_idx]
            coeffs = tran.get_lp_coefficients(batch, gridpoints[:, 0], gridpoints[:, 3], device=device)
            return dict(gridpoints_idx=gridpoints_idx, coeffs=coeffs)

        memory_kwargs = get_memory_kwargs(training_data[:-1])
        valid_data_memory_kwargs = get_memory_kwargs(valid_data[:-1])

    log_prob = model.log_likelihood(training_data, transition_memory_kwargs=transition_memory_kwargs,
                                    **memory_kwargs)
    print("log_prob = {}".format(log_prob))

    ##################### training ############################
    if train_model:
        print("start training")
        # Treat both "" and None as "no saved optimizer".
        if load_opt_dir:
            opt = joblib.load(load_opt_dir)
        else:
            opt = None
        # One training stage (and one checkpoint directory) per (num_iters, lr)
        # pair; the optimizer is carried over between stages.
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters, list_of_lr)):
            training_losses, opt, valid_losses, _ = \
                model.fit(training_data, optimizer=opt, method='adam', num_iters=num_iters, lr=lr,
                          pbar_update_interval=pbar_update_interval, valid_data=valid_data,
                          transition_memory_kwargs=transition_memory_kwargs,
                          valid_data_transition_memory_kwargs=valid_data_transition_memory_kwargs,
                          valid_data_memory_kwargs=valid_data_memory_kwargs, **memory_kwargs)

            checkpoint_dir = rslt_dir + "/checkpoint_{}".format(i)
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))
            # save model and opt
            joblib.dump(model, checkpoint_dir+"/model")
            joblib.dump(opt, checkpoint_dir+"/optimizer")
            # save losses
            losses = dict(training_loss=training_losses, valid_loss=valid_losses)
            joblib.dump(losses, checkpoint_dir+"/losses")

            plt.figure()
            plt.plot(training_losses)
            plt.title("training loss")
            plt.savefig(checkpoint_dir + "/training_losses.jpg")
            plt.close()

            plt.figure()
            plt.plot(valid_losses)
            plt.title("validation loss")
            plt.savefig(checkpoint_dir + "/valid_losses.jpg")
            plt.close()

            # save rest (model, optimizer and losses above are always saved)
            if i in ckpts_not_to_save:
                print("ckpt {}: skip!\n".format(i))
                continue
            with torch.no_grad():
                rslt_saving(rslt_dir=checkpoint_dir, model=model, data=training_data, animal=animal,
                            memory_kwargs=memory_kwargs, list_of_k_steps=list_of_k_steps,
                            sample_T=sample_T, quiver_scale=quiver_scale, valid_data=valid_data,
                            valid_data_memory_kwargs=valid_data_memory_kwargs, device=device)
    else:
        # only save the results
        rslt_saving(rslt_dir=rslt_dir, model=model, data=training_data, animal=animal,
                    memory_kwargs=memory_kwargs, list_of_k_steps=list_of_k_steps,
                    sample_T=sample_T, quiver_scale=quiver_scale, valid_data=valid_data,
                    valid_data_memory_kwargs=valid_data_memory_kwargs, device=device)

    print("Finish running!")
from ssm_ptc.utils import k_step_prediction
import matplotlib.pyplot as plt

# Sample from an AR-Gaussian HMM, refit a fresh model on the samples, and plot
# its predictions against the sampled trajectory.

torch.manual_seed(0)
np.random.seed(0)

# test fitting
K = 3  # number of discrete states
D = 2  # observation dimension
lags= 5  # autoregressive lags

# Generating model: linear AR transformation with `lags` lags.
trans1 = LinearTransformation(K=K, D=D, lags=lags)
obs1 = ARGaussianObservation(K=K, D=D, transformation=trans1)
model1 = HMM(K=K,D=D, observation=obs1)

T = 100
sample_z, sample_x = model1.sample(T)

# Model to fit: built from the 'gaussian' observation string with the same lags.
model2 = HMM(K=K, D=D, observation='gaussian', observation_kwargs=dict(lags=lags))

# NOTE(review): `fit` appears to return (losses, optimizer); `lls` here holds
# the training losses, not log-likelihoods -- confirm.
lls, opt = model2.fit(sample_x, num_iters=2000, lr=0.001)

# Most likely state sequence under the fitted model, used for prediction.
z_infer = model2.most_likely_states(sample_x)

x_predict = k_step_prediction(model2, z_infer, sample_x)

# Compare prediction and truth for the first observation dimension.
plt.figure()
plt.plot(x_predict[:,0], label='prediction')
plt.plot(sample_x[:,0], label='truth')
[3, 8]) # list of length 2, each item is an array (T, 2). T = 36000 rendered_data.append(np.concatenate( (session_data), axis=1)) # each item is an array (T, 4) trajectories = np.concatenate(rendered_data, axis=0) # (T*30, 4) traj29 = rendered_data[29] arena_xmax = 320 arena_ymax = 370 bounds = np.array([[-10, arena_xmax + 10], [-10, arena_ymax + 10], [-10, arena_xmax + 10], [-10, arena_ymax + 10]]) K = 2 D = 4 T = 36000 tran = LinearTransformation(K=K, D=D, lags=1) observation = ARTruncatedNormalObservation(K=K, D=D, M=0, transformation=tran, bounds=bounds) model = HMM(K=K, D=D, M=0, observation=observation) data = torch.tensor(traj29[:5], dtype=torch.float64) out = model.log_likelihood(data) print(out)
# Evaluate an identity-dynamics AR truncated-normal HMM on random data:
# log-likelihood, most likely states, and predictions.
# NOTE(review): `data` is drawn before the seeds below are set, so it is only
# reproducible if the RNG was seeded earlier in the script.
data = torch.randn(T, D, dtype=torch.float64)

lags = 1

# One [low, high] row per dimension; 4 rows, so this presumably requires D == 4.
# NOTE(review): standard-normal data often falls outside [0, 1] on dimensions
# 1 and 3 -- confirm the truncated-normal likelihood is meant to see that.
bounds = np.array([[-2, 2], [0, 1], [-2, 2], [0, 1]])

# Per-state dynamics [I_D, 0]: identity on the most recent lag and zeros on
# older lags. With lags == 1 the zero block is empty, so As has shape (K, D, D).
As = np.array([
    np.column_stack([np.identity(D), np.zeros((D, (lags - 1) * D))])
    for _ in range(K)
])

torch.manual_seed(0)
np.random.seed(0)

tran = LinearTransformation(K=K, D=D, lags=lags, As=As)
observation = ARTruncatedNormalObservation(K=K, D=D, M=0, transformation=tran, bounds=bounds)

model = HMM(K=K, D=D, M=0, observation=observation)

lls = model.log_likelihood(data)
print(lls)

#losses_1, optimizer_1 = model_1.fit(data_1, method='adam', num_iters=2000, lr=0.001)

z_1 = model.most_likely_states(data)

x_predict_arr_lag1 = k_step_prediction(model, z_1, data)
if i % 10 == 0: pbar.set_description('iter {} loss {:.2f}'.format(i, loss)) pbar.update(10) x_reconstruct = model.sample_condition_on_zs(z, data[0]) """ # test momentum_lags true_observation = ARGaussianObservation(K=K, D=D, M=0, lags=5, transformation='linear') true_model = HMM(K=K, D=D, M=0, observation=true_observation) z, data = true_model.sample(T, return_np=True) # fit to a model observation = ARGaussianObservation(K=K, D=D, M=0, lags=5, transformation='linear') model = HMM(K=K, D=D, M=0, observation=observation) num_iters = 10000 pbar = tqdm(total=num_iters, file=sys.stdout)
K = 4 D = 4 momentum_lags = 30 momentum_weights = np.arange(0.55, 2.05, 0.05) T = 36000 observation = MomentumInteractionObservation(K=K, D=D, bounds=bounds, momentum_lags=momentum_lags, momentum_weights=momentum_weights, max_v=max_v) model = HMM(K=K, D=D, M=0, observation=observation) """ ##################### test params ######################## obs2 = CoupledMomentumObservation(K=K, D=D, M=0, momentum_lags=momentum_lags, Df=Df, feature_func=feature_func_single, bounds=bounds) model2 = HMM(K=K, D=D, M=0, observation=obs2) model2.params = model.params for p1, p2 in zip(model.params_unpack, model2.params_unpack): assert torch.all(torch.eq(p1, p2)) """ # precompute features momentum_vecs = MomentumInteractionTransformation._compute_momentum_vecs(
def test_stick_transition():
    """Exercise the 'sticky' transition.

    Samples from a reference model with kappa=100, then fits fresh models with
    kappa in (0.1, 20, 100) to those samples. Each fresh model's stationary
    transition matrix is printed before and after fitting.
    """
    K, D, M = 3, 2, 0

    reference = HMM(K=K, D=D, M=M, transition='sticky', observation='gaussian',
                    transition_kwargs=dict(alpha=1, kappa=100))
    print("alpha = {}, kappa = {}".format(reference.transition.alpha, reference.transition.kappa))

    toy_data = torch.tensor([[2, 3], [4, 5], [6, 7]], dtype=torch.float64)
    print("log_prob", reference.log_likelihood(toy_data))

    samples_z, samples_x = reference.sample(10)
    print("sample log_prob", reference.log_probability(samples_x))
    print("model transition\n", reference.transition.stationary_transition_matrix)

    # (label printed alongside the matrix, kappa used for the fresh model)
    fit_configs = (
        ("model1 transition\n", 0.1),
        ("model2 transition\n", 20),
        ("model3 transition\n", 100),
    )
    for label, kappa in fit_configs:
        candidate = HMM(K=K, D=D, M=M, transition='sticky', observation='gaussian',
                        transition_kwargs=dict(alpha=1, kappa=kappa))
        print("Before fit")
        print(label, candidate.transition.stationary_transition_matrix)
        candidate.fit(samples_x)
        print("After fit")
        print(label, candidate.transition.stationary_transition_matrix)
torch.manual_seed(0) np.random.seed(0) K = 3 D = 2 lags = 10 # AR Gaussian trans1 = LinearTransformation(K=K, D=D, lags=lags) obs1 = ARGaussianObservation(K=K, D=D, transformation=trans1, train_sigma=False) model1 = HMM(K=K, D=D, observation=obs1) model2 = HMM(K=K, D=D, observation_kwargs={"lags": lags}) #print(model1.params == model2.params) model2.params = model1.params for p1, p2 in zip(model1.params_unpack, model2.params_unpack): assert torch.all(torch.eq(p1, p2)) # AR LogitNormal bounds = np.array([[0, 2], [0, 4]]) model1 = HMM(K=K,
def _build_grids(grids, n_cells, lo, hi):
    """Return ``(edges, n_cells)`` describing the grid along one arena axis.

    Args:
        grids: None, or a comma-separated string of grid-edge coordinates.
        n_cells: number of cells to generate when ``grids`` is None.
        lo, hi: axis limits used to generate evenly spaced edges when
            ``grids`` is None.

    Returns:
        A numpy array of edges and the number of cells between them. When
        ``grids`` is given, the cell count is recomputed from the parsed
        edges and the incoming ``n_cells`` is ignored.
    """
    if grids is None:
        gap = (hi - lo) / n_cells
        # Keep the ``lo + i * gap`` formula (not np.linspace) so generated
        # edges are bit-for-bit identical to the previous implementation.
        edges = np.array([lo + i * gap for i in range(n_cells + 1)])
        return edges, n_cells
    edges = np.array([float(x) for x in grids.split(",")])
    return edges, len(edges) - 1


def main(job_name, lstm_model_type, downsample_n, filter_traj, load_model,
         load_model_dir, load_opt_dir, transition, sticky_alpha, sticky_kappa,
         acc_factor, lags, dh, dhs, k, x_grids, y_grids, n_x, n_y,
         train_model, pbar_update_interval, video_clips, torch_seed, np_seed,
         list_of_num_iters, list_of_lr, list_of_k_steps, sample_t,
         quiver_scale):
    """Build or load an LSTM-transformation HMM, optionally train it, and save results.

    Several arguments arrive as comma-separated strings and are parsed here:
    ``video_clips`` ("start,end" floats), ``dhs`` (ints, or 'none'/None),
    ``list_of_num_iters`` (ints), ``list_of_lr`` (floats),
    ``list_of_k_steps`` (ints), and optionally ``x_grids`` / ``y_grids``
    (float grid edges; when None, ``n_x`` / ``n_y`` evenly spaced cells
    between the ARENA_* limits are generated).

    When ``load_model`` is true the model is read from ``load_model_dir``
    and ``K``, ``acc_factor`` and ``lstm_model_type`` are taken from it;
    otherwise a new HMM with an ``LSTMTransformation`` and an
    ``ARTruncatedNormalObservation`` is created.

    When ``train_model`` is true, ``model.fit`` runs once per
    ``(num_iters, lr)`` pair and a checkpoint (model, optimizer, results) is
    written after each stage; otherwise only the results are written.

    Raises:
        ValueError: if ``job_name`` is None, ``lstm_model_type`` is not
            'uni' or 'multi', ``list_of_num_iters`` and ``list_of_lr`` differ
            in length, or any learning rate exceeds 1.
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")
    # Raise instead of assert so the check survives `python -O`.
    if lstm_model_type not in ("uni", "multi"):
        raise ValueError(
            "lstm_model_type must be chosen from 'uni' and 'multi', "
            "but got {!r}.".format(lstm_model_type))

    K = k
    sample_T = sample_t

    video_clip_start, video_clip_end = [float(x) for x in video_clips.split(",")]

    # Accept a missing value as well as the literal string 'none'.
    if dhs is None or dhs == 'none':
        dhs = None
    else:
        dhs = [int(d_hidden) for d_hidden in dhs.split(",")]

    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    list_of_k_steps = [int(x) for x in list_of_k_steps.split(",")]
    if len(list_of_num_iters) != len(list_of_lr):
        raise ValueError(
            "Length of list_of_num_iters must match length of list_of_lr,"
            " but we have list_of_num_iters = {}, and list_of_lr = {}".format(
                list_of_num_iters, list_of_lr))
    if any(lr > 1 for lr in list_of_lr):
        raise ValueError("Learning rate should not be larger than 1!")

    # Paths below are resolved relative to the enclosing git checkout.
    repo = git.Repo('.', search_parent_directories=True)
    repo_dir = repo.working_tree_dir

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    ########################## data ########################
    data_dir = os.path.join(repo_dir, 'SocialBehaviorptc', 'data', 'trajs_all')
    trajs = joblib.load(data_dir)

    # The clip fractions are scaled by 36000 to get row indices.
    # NOTE(review): presumably 36000 rows per video clip — confirm.
    start = int(36000 * video_clip_start)
    end = int(36000 * video_clip_end)
    traj = trajs[start:end]
    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)

    data = torch.tensor(traj, dtype=torch.float64)

    ######################### model ####################
    D = 4
    M = 0
    Df = 4

    if load_model:
        print("Loading the model from ", load_model_dir)
        model = joblib.load(load_model_dir)
        tran = model.observation.transformation
        assert isinstance(tran, (LSTMTransformation, UniLSTMTransformation)), \
            "tran should be {}, but is {}".format(LSTMTransformation, type(tran))
        K = model.K
        acc_factor = tran.acc_factor
        lstm_model_type = "uni" if isinstance(tran, UniLSTMTransformation) else "multi"
        # NOTE(review): `lags` (used for packed_data below and recorded in
        # exp_params) still comes from the arguments, not the loaded model —
        # confirm the two always agree.
    else:
        print("Creating the model...")
        bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX],
                           [ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])
        if acc_factor is None:
            acc_factor = downsample_n * 10
        # NOTE(review): an LSTMTransformation is created even when
        # lstm_model_type == 'uni', yet results are saved under
        # rslts/lstm/uni/ — confirm whether UniLSTMTransformation was intended.
        tran = LSTMTransformation(K=K, D=D, Df=Df,
                                  feature_vec_func=f_corner_vec_func,
                                  lags=lags, dh=dh, dhs=dhs,
                                  acc_factor=acc_factor)
        obs = ARTruncatedNormalObservation(K=K, D=D, M=M, lags=lags,
                                           bounds=bounds, transformation=tran)
        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        else:
            transition_kwargs = None
        model = HMM(K=K, D=D, M=M, transition=transition, observation=obs,
                    transition_kwargs=transition_kwargs)
        # mus_init: the first observation repeated for each of the K rows.
        model.observation.mus_init = data[0] * torch.ones(K, D, dtype=torch.float64)

    # grids
    x_grids, n_x = _build_grids(x_grids, n_x, ARENA_XMIN, ARENA_XMAX)
    y_grids, n_y = _build_grids(y_grids, n_y, ARENA_YMIN, ARENA_YMAX)

    # save experiment params
    exp_params = {
        "job_name": job_name,
        # NOTE(review): key is misspelled ("tyoe"); kept unchanged in case
        # existing readers of exp_params.json depend on it.
        "lstm_model_tyoe": lstm_model_type,
        'downsample_n': downsample_n,
        "filter_traj": filter_traj,
        "load_model": load_model,
        "load_model_dir": load_model_dir,
        "load_opt_dir": load_opt_dir,
        "transition": transition,
        "sticky_alpha": sticky_alpha,
        "sticky_kappa": sticky_kappa,
        "acc_factor": acc_factor,
        "K": K,
        "lags": lags,
        "dh": dh,
        "dhs": dhs,
        "n_x": n_x,
        "n_y": n_y,
        "x_grids": x_grids,
        "y_grids": y_grids,
        "train_model": train_model,
        "pbar_update_interval": pbar_update_interval,
        "list_of_num_iters": list_of_num_iters,
        "list_of_lr": list_of_lr,
        "video_clip_start": video_clip_start,
        "video_clip_end": video_clip_end,
        "list_of_k_steps": list_of_k_steps,
        "sample_T": sample_T,
    }

    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/lstm/{}/{}".format(lstm_model_type, job_name))
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving to rlst_dir: ", rslt_dir)
    with open(os.path.join(rslt_dir, "exp_params.json"), "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    # compute memory: derived from the data once, then passed as keyword
    # arguments to model.fit and rslt_saving.
    print("Computing memory...")
    feature_vecs_a = f_corner_vec_func(data[:-1, 0:2])  # (T, Df, 2)
    feature_vecs_b = f_corner_vec_func(data[:-1, 2:4])  # (T, Df, 2)
    feature_vecs = (feature_vecs_a, feature_vecs_b)
    packed_data = get_packed_data(data[:-1], lags=lags)
    memory_kwargs = dict(packed_data=packed_data, feature_vecs=feature_vecs)

    ##################### training ############################
    if train_model:
        print("start training")
        # Both None and "" mean "no saved optimizer to resume from".
        opt = joblib.load(load_opt_dir) if load_opt_dir else None
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters, list_of_lr)):
            # The optimizer returned by each stage is passed into the next.
            losses, opt = model.fit(data, optimizer=opt, method='adam',
                                    num_iters=num_iters, lr=lr,
                                    pbar_update_interval=pbar_update_interval,
                                    **memory_kwargs)

            checkpoint_dir = os.path.join(rslt_dir, "checkpoint_{}".format(i))
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))
            # save model and opt
            joblib.dump(model, os.path.join(checkpoint_dir, "model"))
            joblib.dump(opt, os.path.join(checkpoint_dir, "optimizer"))
            # save the remaining results for this stage
            rslt_saving(checkpoint_dir, model, data, memory_kwargs,
                        list_of_k_steps, sample_T, train_model, losses,
                        quiver_scale, x_grids=x_grids, y_grids=y_grids)
    else:
        # only save the results
        rslt_saving(rslt_dir, model, data, memory_kwargs, list_of_k_steps,
                    sample_T, False, [], quiver_scale,
                    x_grids=x_grids, y_grids=y_grids)

    print("Finish running!")