def k_step_prediction_for_lineargrid_model(model, model_z, data, **kwargs):
    if len(data) == 0:
        return None
    data = check_and_convert_to_tensor(data)
    _, D = data.shape
    assert D == 2 or D == 4

    if D == 4:
        feature_vecs = kwargs.get("feature_vecs", None)
        gridpoints = kwargs.get("gridpoints", None)
        gridpoints_idx = kwargs.get("gridpoints_idx", None)
        if feature_vecs is None or gridpoints_idx is None or gridpoints is None:
            print("Did not provide memory information")
            return k_step_prediction(model, model_z, data)
        else:
            grid_points_idx_a, grid_points_idx_b = gridpoints_idx
            gridpoints_a, gridpoints_b = gridpoints
            feature_vecs_a, feature_vecs_b = feature_vecs

            x_predict_arr = []
            x_predict = model.observation.sample_x(model_z[0], data[:0], return_np=True, with_noise=True)
            x_predict_arr.append(x_predict)
            for t in range(1, data.shape[0]):
                x_predict = model.observation.sample_x(
                    model_z[t], data[t - 1:t], return_np=True, with_noise=True,
                    gridpoints=(gridpoints_a[t - 1], gridpoints_b[t - 1]),
                    gridpoints_idx=(grid_points_idx_a[t - 1], grid_points_idx_b[t - 1]),
                    feature_vec=(feature_vecs_a[t - 1:t], feature_vecs_b[t - 1:t]))
                x_predict_arr.append(x_predict)
            x_predict_arr = np.array(x_predict_arr)
    else:
        coeffs = kwargs.get("coeffs", None)
        gridpoints_idx = kwargs.get("gridpoints_idx", None)
        if coeffs is None or gridpoints_idx is None:
            print("Did not provide memory information")
            return k_step_prediction(model, model_z, data)
        else:
            x_predict_arr = []
            x_predict = model.observation.sample_x(model_z[0], data[:0], return_np=True, with_noise=True)
            x_predict_arr.append(x_predict)
            for t in range(1, data.shape[0]):
                x_predict = model.observation.sample_x(
                    model_z[t], data[t - 1:t], return_np=True, with_noise=True,
                    coeffs=coeffs[t - 1:t], gridpoints_idx=gridpoints_idx[t - 1])
                x_predict_arr.append(x_predict)
            x_predict_arr = np.array(x_predict_arr)
    return x_predict_arr
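# Illustrative usage (a minimal sketch, not part of the original module): assuming `z` was obtained from
# `model.most_likely_states(...)` and the memory tensors were precomputed on data[:-1] as in the
# LinearGridTransformation test below, the cached quantities are passed straight through as keywords, e.g.
#   x_predict = k_step_prediction_for_lineargrid_model(model, z, data,
#                                                      gridpoints=gridpoints,
#                                                      gridpoints_idx=gridpoints_idx,
#                                                      feature_vecs=feature_vecs)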
def k_step_prediction_for_lstm_model(model, model_z, data, feature_vecs=None):
    data = check_and_convert_to_tensor(data)
    if feature_vecs is None:
        print("Did not provide memory information")
        return k_step_prediction(model, model_z, data)
    else:
        feature_vecs_a, feature_vecs_b = feature_vecs

        x_predict_arr = []
        x_predict = model.observation.sample_x(model_z[0], data[:0], return_np=True)
        x_predict_arr.append(x_predict)
        for t in range(1, data.shape[0]):
            feature_vec_t = (feature_vecs_a[t - 1:t], feature_vecs_b[t - 1:t])
            x_predict = model.observation.sample_x(model_z[t], data[:t], return_np=True, with_noise=True,
                                                   feature_vec=feature_vec_t)
            x_predict_arr.append(x_predict)
        x_predict_arr = np.array(x_predict_arr)
        return x_predict_arr
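# Illustrative usage (a minimal sketch, not part of the original module): `feature_vecs` is the pair of
# per-animal feature tensors computed on data[:-1], as in the LSTM test further down, e.g.
#   feature_vecs = (feature_vecs_a, feature_vecs_b)
#   x_predict = k_step_prediction_for_lstm_model(model, z, data, feature_vecs=feature_vecs)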
def k_step_prediction_for_momentum_feature_model(model, model_z, data, momentum_vecs=None, features=None):
    data = check_and_convert_to_tensor(data)
    if momentum_vecs is None or features is None:
        return k_step_prediction(model, model_z, data)
    else:
        x_predict_arr = []
        x_predict = model.observation.sample_x(model_z[0], data[:0], return_np=True)
        x_predict_arr.append(x_predict)
        for t in range(1, data.shape[0]):
            x_predict = model.observation.sample_x(
                model_z[t], data[:t], return_np=True, with_noise=True,
                momentum_vec=momentum_vecs[t - 1],
                features=(features[0][t - 1], features[1][t - 1]))
            x_predict_arr.append(x_predict)
        x_predict_arr = np.array(x_predict_arr)
        return x_predict_arr
def k_step_prediction_for_gpgrid_model(model, model_z, data, **memory_kwargs):
    data = check_and_convert_to_tensor(data)
    if memory_kwargs == {}:
        print("Did not provide memory information")
        return k_step_prediction(model, model_z, data)
    else:
        feature_vecs_a = memory_kwargs.get("feature_vecs_a", None)
        feature_vecs_b = memory_kwargs.get("feature_vecs_b", None)
        gpt_idx_a = memory_kwargs.get("gpt_idx_a", None)
        gpt_idx_b = memory_kwargs.get("gpt_idx_b", None)
        grid_idx_a = memory_kwargs.get("grid_idx_a", None)
        grid_idx_b = memory_kwargs.get("grid_idx_b", None)
        coeff_a = memory_kwargs.get("coeff_a", None)
        coeff_b = memory_kwargs.get("coeff_b", None)
        dist_sq_a = memory_kwargs.get("dist_sq_a", None)
        dist_sq_b = memory_kwargs.get("dist_sq_b", None)

        x_predict_arr = []
        x_predict = model.observation.sample_x(model_z[0], data[:0], return_np=True)
        x_predict_arr.append(x_predict)
        for t in range(1, data.shape[0]):
            if dist_sq_a is None:
                x_predict = model.observation.sample_x(
                    model_z[t], data[:t], return_np=True, with_noise=True,
                    feature_vec_a=feature_vecs_a[t - 1:t], feature_vec_b=feature_vecs_b[t - 1:t],
                    gpt_idx_a=gpt_idx_a[t - 1:t], gpt_idx_b=gpt_idx_b[t - 1:t],
                    grid_idx_a=grid_idx_a[t - 1:t], grid_idx_b=grid_idx_b[t - 1:t],
                    coeff_a=coeff_a[t - 1:t], coeff_b=coeff_b[t - 1:t])
            else:
                x_predict = model.observation.sample_x(
                    model_z[t], data[:t], return_np=True, with_noise=True,
                    feature_vec_a=feature_vecs_a[t - 1:t], feature_vec_b=feature_vecs_b[t - 1:t],
                    gpt_idx_a=gpt_idx_a[t - 1:t], gpt_idx_b=gpt_idx_b[t - 1:t],
                    grid_idx_a=grid_idx_a[t - 1:t], grid_idx_b=grid_idx_b[t - 1:t],
                    dist_sq_a=dist_sq_a[t - 1:t], dist_sq_b=dist_sq_b[t - 1:t])
            x_predict_arr.append(x_predict)
        x_predict_arr = np.array(x_predict_arr)
        return x_predict_arr
def k_step_prediction_for_grid_model(model, model_z, data, **memory_kwargs):
    if len(data) == 0:
        return None
    data = check_and_convert_to_tensor(data)
    memory_kwargs_a = memory_kwargs.get("memory_kwargs_a", None)
    memory_kwargs_b = memory_kwargs.get("memory_kwargs_b", None)
    if memory_kwargs_a is None or memory_kwargs_b is None:
        print("Did not provide memory information")
        return k_step_prediction(model, model_z, data)
    else:
        momentum_vecs_a = memory_kwargs_a.get("momentum_vecs", None)
        feature_vecs_a = memory_kwargs_a.get("feature_vecs", None)
        momentum_vecs_b = memory_kwargs_b.get("momentum_vecs", None)
        feature_vecs_b = memory_kwargs_b.get("feature_vecs", None)

        x_predict_arr = []
        x_predict = model.observation.sample_x(model_z[0], data[:0], return_np=True)
        x_predict_arr.append(x_predict)
        for t in range(1, data.shape[0]):
            if momentum_vecs_a is None:
                m_kwargs_a = dict(feature_vec=feature_vecs_a[t - 1])
                m_kwargs_b = dict(feature_vec=feature_vecs_b[t - 1])
            else:
                m_kwargs_a = dict(momentum_vec=momentum_vecs_a[t - 1], feature_vec=feature_vecs_a[t - 1])
                m_kwargs_b = dict(momentum_vec=momentum_vecs_b[t - 1], feature_vec=feature_vecs_b[t - 1])
            x_predict = model.observation.sample_x(model_z[t], data[:t], return_np=True, with_noise=True,
                                                   memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)
            x_predict_arr.append(x_predict)
        x_predict_arr = np.array(x_predict_arr)
        return x_predict_arr
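# Illustrative usage (a minimal sketch, not part of the original module): `m_kwargs_a` / `m_kwargs_b` are the
# same per-animal memory dictionaries that were passed to `model.most_likely_states`, as in the grid-model
# script further down, e.g.
#   x_predict = k_step_prediction_for_grid_model(model, z, data,
#                                                memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)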
import matplotlib.pyplot as plt

torch.manual_seed(0)
np.random.seed(0)

# test fitting
K = 3
D = 2
lags = 5

trans1 = LinearTransformation(K=K, D=D, lags=lags)
obs1 = ARGaussianObservation(K=K, D=D, transformation=trans1)
model1 = HMM(K=K, D=D, observation=obs1)

T = 100
sample_z, sample_x = model1.sample(T)

model2 = HMM(K=K, D=D, observation='gaussian', observation_kwargs=dict(lags=lags))
lls, opt = model2.fit(sample_x, num_iters=2000, lr=0.001)

z_infer = model2.most_likely_states(sample_x)
x_predict = k_step_prediction(model2, z_infer, sample_x)

plt.figure()
plt.plot(x_predict[:, 0], label='prediction')
plt.plot(sample_x[:, 0], label='truth')
plt.show()
def k_step_prediction_for_gpmodel(model, model_z, data, **memory_kwargs):
    data = check_and_convert_to_tensor(data)
    T, D = data.shape
    assert D == 4 or D == 2, D
    K = model.observation.K

    if memory_kwargs == {}:
        print("Did not provide memory information")
        return k_step_prediction(model, model_z, data)
    else:
        # compute As
        if D == 4:
            _, A_a = model.observation.get_gp_cache(data[:-1, 0:2], 0, A_only=True, **memory_kwargs)
            _, A_b = model.observation.get_gp_cache(data[:-1, 2:4], 1, A_only=True, **memory_kwargs)
            assert A_a.shape == A_b.shape == (T - 1, K, 2, model.observation.n_gps * 2), \
                "{}, {}".format(A_a.shape, A_b.shape)

            x_predict_arr = []
            x_predict = model.observation.sample_x(model_z[0], data[:0], return_np=True)
            x_predict_arr.append(x_predict)
            for t in range(1, data.shape[0]):
                x_predict = model.observation.sample_x(model_z[t], data[:t], return_np=True, with_noise=True,
                                                       A_a=A_a[t - 1:t, model_z[t]],
                                                       A_b=A_b[t - 1:t, model_z[t]])
                x_predict_arr.append(x_predict)
        else:
            _, A = model.observation.get_gp_cache(data[:-1], A_only=True, **memory_kwargs)

            x_predict_arr = []
            x_predict = model.observation.sample_x(model_z[0], data[:0], return_np=True)
            x_predict_arr.append(x_predict)
            for t in range(1, data.shape[0]):
                x_predict = model.observation.sample_x(model_z[t], data[:t], return_np=True, with_noise=True,
                                                       A=(A[0][model_z[t], t - 1:t], A[1][model_z[t], t - 1:t]))
                x_predict_arr.append(x_predict)

        x_predict_arr = np.array(x_predict_arr)
        return x_predict_arr
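# Illustrative usage (a minimal sketch, not part of the original module): the cached kernel distances used
# during fitting can be reused here, as in the GPObservation test further down, e.g.
#   memory_kwargs = dict(kernel_distsq_xg_a=kernel_distsq_xg_a, kernel_distsq_xg_b=kernel_distsq_xg_b)
#   x_predict = k_step_prediction_for_gpmodel(model, z, data, **memory_kwargs)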
def rslt_saving(rslt_dir, model, data, animal, memory_kwargs, list_of_k_steps, sample_T, quiver_scale,
                x_grids=None, y_grids=None, dynamics_T=None, valid_data=None, valid_data_memory_kwargs=None,
                device=torch.device('cpu')):
    valid_data_memory_kwargs = valid_data_memory_kwargs if valid_data_memory_kwargs else {}

    tran = model.observation.transformation
    if x_grids is None or y_grids is None:
        x_grids = tran.x_grids
        y_grids = tran.y_grids
    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1
    K = model.K

    memory_kwargs = memory_kwargs if memory_kwargs else {}

    #################### inference ###########################
    print("\ninferring most likely states...")
    z = model.most_likely_states(data, **memory_kwargs)
    z_valid = model.most_likely_states(valid_data, **valid_data_memory_kwargs)  # TODO: address valid_data = None

    print("0 step prediction")
    # TODO: add valid data for other model
    if data.shape[0] <= 10000:
        data_to_predict = data
    else:
        data_to_predict = data[-10000:]

    if isinstance(tran, (LinearGridTransformation, SingleLinearGridTransformation)):
        x_predict = k_step_prediction_for_lineargrid_model(model, z, data_to_predict, **memory_kwargs)
        x_predict_valid = k_step_prediction_for_lineargrid_model(model, z_valid, valid_data,
                                                                 **valid_data_memory_kwargs)
    elif isinstance(tran, GPGridTransformation):
        x_predict = k_step_prediction_for_gpgrid_model(model, z, data_to_predict, **memory_kwargs)
        x_predict_valid = k_step_prediction_for_gpgrid_model(model, z_valid, valid_data, **valid_data_memory_kwargs)
    elif isinstance(tran, WeightedGridTransformation):
        x_predict = k_step_prediction_for_weightedgrid_model(model, z, data_to_predict, **memory_kwargs)
        x_predict_valid = k_step_prediction_for_weightedgrid_model(model, z_valid, valid_data,
                                                                   **valid_data_memory_kwargs)
    elif isinstance(tran, (LSTMTransformation, UniLSTMTransformation)):
        x_predict = k_step_prediction_for_lstm_model(model, z, data_to_predict,
                                                     feature_vecs=memory_kwargs["feature_vecs"])
        x_predict_valid = k_step_prediction_for_lstm_model(model, z_valid, valid_data,
                                                           feature_vecs=valid_data_memory_kwargs["feature_vecs"])
    elif isinstance(tran, LSTMBasedTransformation):
        x_predict = k_step_prediction_for_lstm_based_model(model, z, data_to_predict, k=0,
                                                           feature_vecs=memory_kwargs["feature_vecs"])
        x_predict_valid = k_step_prediction_for_lstm_based_model(model, z_valid, valid_data, k=0,
                                                                 feature_vecs=valid_data_memory_kwargs["feature_vecs"])
    else:
        raise ValueError("Unsupported transformation!")

    x_predict_err = np.mean(np.abs(x_predict - get_np(data_to_predict)), axis=0)
    if len(valid_data) == 0:
        x_predict_valid_err = None
    else:
        x_predict_valid_err = np.mean(np.abs(x_predict_valid - get_np(valid_data)), axis=0)

    dict_of_x_predict_k = dict(x_predict_0=x_predict, x_predict_v_0=x_predict_valid)
    dict_of_x_predict_k_err = dict(x_predict_0_err=x_predict_err, x_predict_v_0_err=x_predict_valid_err)

    for k_step in list_of_k_steps:
        print("{} step prediction".format(k_step))
        if isinstance(tran, LSTMBasedTransformation):
            # TODO: take care of empty valid data
            x_predict_k = k_step_prediction_for_lstm_based_model(model, z, data_to_predict, k=k_step)
            x_predict_valid_k = k_step_prediction_for_lstm_based_model(model, z_valid, valid_data, k=k_step)
        else:
            x_predict_k = k_step_prediction(model, z, data_to_predict, k=k_step)
            x_predict_valid_k = k_step_prediction(model, z_valid, valid_data, k=k_step)

        x_predict_k_err = np.mean(np.abs(x_predict_k - get_np(data_to_predict[k_step:])), axis=0)
        if len(valid_data) == 0:
            x_predict_valid_k_err = None
        else:
            x_predict_valid_k_err = np.mean(np.abs(x_predict_valid_k - get_np(valid_data[k_step:])), axis=0)

        dict_of_x_predict_k["x_predict_{}".format(k_step)] = x_predict_k
        dict_of_x_predict_k["x_predict_v_{}".format(k_step)] = x_predict_valid_k
        dict_of_x_predict_k_err["x_predict_{}_err".format(k_step)] = x_predict_k_err
        dict_of_x_predict_k_err["x_predict_v_{}_err".format(k_step)] = x_predict_valid_k_err

    ################### samples #########################
    print("sampling")
    center_z = torch.tensor([0], dtype=torch.int, device=device)
    if animal == "both":
        center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64, device=device)
    else:
        center_x = torch.tensor([[150, 190]], dtype=torch.float64, device=device)

    if isinstance(tran, LSTMBasedTransformation):
        lstm_states = {}
        sample_z, sample_x = model.sample(sample_T, lstm_states=lstm_states)
        lstm_states = {}
        sample_z_center, sample_x_center = model.sample(sample_T, prefix=(center_z, center_x),
                                                        lstm_states=lstm_states)
    else:
        sample_z, sample_x = model.sample(sample_T)
        sample_z_center, sample_x_center = model.sample(sample_T, prefix=(center_z, center_x))

    ################## dynamics #####################
    if isinstance(tran, (LinearGridTransformation, SingleLinearGridTransformation,
                         GPGridTransformation, WeightedGridTransformation)):
        # quiver
        if animal == 'both':
            XX, YY = np.meshgrid(np.linspace(20, 310, 30), np.linspace(0, 380, 30))
            XY = np.column_stack((np.ravel(XX), np.ravel(YY)))  # shape (900, 2) grid values
            XY_grids = np.concatenate((XY, XY), axis=1)  # (900, 4)
        else:
            XX, YY = np.meshgrid(np.linspace(20, 310, 30), np.linspace(0, 380, 30))
            XY_grids = np.column_stack((np.ravel(XX), np.ravel(YY)))  # shape (900, 2) grid values

        XY_next = tran.transform(torch.tensor(XY_grids, dtype=torch.float64, device=device))
        dXY = get_np(XY_next) - XY_grids[:, None]

    # TODO: maybe use sample conditioned on z (transformation) to show the dynamics
    samples_on_fixed_zs = []
    if isinstance(tran, LSTMBasedTransformation):
        assert dynamics_T is not None
        for k in range(K):
            lstm_states = {}
            fixed_z = torch.ones(dynamics_T, dtype=torch.int) * k
            samples_on_fixed_z = model.sample_condition_on_zs(zs=fixed_z, transformation=True, return_np=True,
                                                              lstm_states=lstm_states)
            samples_on_fixed_zs.append(samples_on_fixed_z)

    #################### saving ##############################
    print("begin saving...")

    # save summary
    if isinstance(tran, (LinearGridTransformation, SingleLinearGridTransformation,
                         GPGridTransformation, WeightedGridTransformation)):
        avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = np.average(np.abs(np.diff(sample_x, axis=0)), axis=0)
    avg_sample_center_speed = np.average(np.abs(np.diff(sample_x_center, axis=0)), axis=0)
    avg_data_speed = np.average(np.abs(np.diff(get_np(data), axis=0)), axis=0)

    if isinstance(model.transition, StationaryTransition):
        transition_matrix = model.transition.stationary_transition_matrix
    elif isinstance(model.transition, GridTransition):
        transition_matrix = model.transition.grid_transition_matrix
    else:
        raise ValueError("unsupported transition matrix type: {}".format(type(model.transition)))
    transition_matrix = get_np(transition_matrix)

    summary_dict = {"init_dist": get_np(model.init_dist),
                    "transition_matrix": transition_matrix,
                    "variance": get_np(torch.exp(model.observation.log_sigmas)),
                    "log_likes": get_np(model.log_likelihood(data, **memory_kwargs)),
                    "avg_data_speed": avg_data_speed,
                    "avg_sample_speed": avg_sample_speed,
                    "avg_sample_center_speed": avg_sample_center_speed}
    summary_dict = {**dict_of_x_predict_k_err, **summary_dict}
    if len(valid_data) > 0:
        summary_dict["valid_log_likes"] = get_np(model.log_likelihood(valid_data, **valid_data_memory_kwargs))
    if isinstance(tran, GPGridTransformation):
        summary_dict["real_rs"] = get_np(tran.rs_factor * torch.sigmoid(tran.rs))
    if isinstance(tran, WeightedGridTransformation):
        summary_dict["beta"] = get_np(tran.beta)
    if isinstance(tran, (LinearGridTransformation, GPGridTransformation, WeightedGridTransformation)):
        summary_dict["avg_transform_speed"] = avg_transform_speed

    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = {"z": z, "z_valid": z_valid,
                   "sample_z": sample_z, "sample_x": sample_x,
                   "sample_z_center": sample_z_center, "sample_x_center": sample_x_center}
    saving_dict = {**dict_of_x_predict_k, **saving_dict}
    if isinstance(tran, LSTMBasedTransformation):
        saving_dict["samples_on_fixed_zs"] = samples_on_fixed_zs
    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    if model.D == 2 and isinstance(model.transition, GridTransition):
        plot_grid_transition(n_x, n_y, model.transition.grid_transition_matrix)
        plt.savefig(rslt_dir + "/grid_transition.jpg")
        plt.close()

    plot_z(z, K, title="most likely z for the ground truth")
    plt.savefig(rslt_dir + "/z.jpg")
    plt.close()

    if len(valid_data) > 0:
        plot_z(z_valid, K, title="most likely z for valid data")
        plt.savefig(rslt_dir + "/z_valid.jpg")
        plt.close()

    if not os.path.exists(rslt_dir + "/samples"):
        os.makedirs(rslt_dir + "/samples")
        print("Making samples directory...")

    plot_z(sample_z, K, title="sample")
    plt.savefig(rslt_dir + "/samples/sample_z_{}.jpg".format(sample_T))
    plt.close()

    plot_z(sample_z_center, K, title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/sample_z_center_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(data, title="ground truth (training)",
               xlim=[ARENA_XMIN - 20, ARENA_YMAX + 20], ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/ground_truth.jpg")
    plt.close()

    if len(valid_data) > 0:
        plt.figure(figsize=(4, 4))
        plot_mouse(valid_data, title="ground truth (valid)",
                   xlim=[ARENA_XMIN - 20, ARENA_YMAX + 20], ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
        plt.legend()
        plt.savefig(rslt_dir + "/samples/ground_truth_valid.jpg")
        plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x, title="sample",
               xlim=[ARENA_XMIN - 20, ARENA_YMAX + 20], ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x_center, title="sample (starting from center)",
               xlim=[ARENA_XMIN - 20, ARENA_YMAX + 20], ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_center_{}.jpg".format(sample_T))
    plt.close()

    plot_realdata_quiver(data, z, K, x_grids, y_grids, title="ground truth (training)")
    plt.savefig(rslt_dir + "/samples/quiver_ground_truth.jpg", dpi=200)
    plt.close()

    if len(valid_data) > 0:
        plot_realdata_quiver(valid_data, z_valid, K, x_grids, y_grids, title="ground truth (valid)")
        plt.savefig(rslt_dir + "/samples/quiver_ground_truth_valid.jpg", dpi=200)
        plt.close()

    plot_realdata_quiver(sample_x, sample_z, K, x_grids, y_grids, title="sample")
    plt.savefig(rslt_dir + "/samples/quiver_sample_x_{}.jpg".format(sample_T), dpi=200)
    plt.close()

    plot_realdata_quiver(sample_x_center, sample_z_center, K, x_grids, y_grids,
                         title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/quiver_sample_x_center_{}.jpg".format(sample_T), dpi=200)
    plt.close()

    if isinstance(tran, (LinearGridTransformation, SingleLinearGridTransformation,
                         GPGridTransformation, WeightedGridTransformation)):
        if not os.path.exists(rslt_dir + "/dynamics"):
            os.makedirs(rslt_dir + "/dynamics")
            print("Making dynamics directory...")

        if animal == 'both':
            plot_quiver(XY_grids[:, 0:2], dXY[..., 0:2], 'virgin', K=K, scale=quiver_scale, alpha=0.9,
                        title="quiver (virgin)", x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
            plt.savefig(rslt_dir + "/dynamics/quiver_a.jpg", dpi=200)
            plt.close()

            plot_quiver(XY_grids[:, 2:4], dXY[..., 2:4], 'mother', K=K, scale=quiver_scale, alpha=0.9,
                        title="quiver (mother)", x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
            plt.savefig(rslt_dir + "/dynamics/quiver_b.jpg", dpi=200)
            plt.close()
        else:
            plot_quiver(XY_grids, dXY, animal, K=K, scale=quiver_scale, alpha=0.9,
                        title="quiver ({})".format(animal), x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
            plt.savefig(rslt_dir + "/dynamics/quiver_{}.jpg".format(animal), dpi=200)
            plt.close()
    elif isinstance(tran, LSTMBasedTransformation):
        if not os.path.exists(rslt_dir + "/dynamics"):
            os.makedirs(rslt_dir + "/dynamics")
            print("Making dynamics directory...")

        for k in range(K):
            plot_realdata_quiver(samples_on_fixed_zs[k], np.ones(dynamics_T, dtype=int) * k, K,
                                 x_grids=x_grids, y_grids=y_grids, title="sample conditioned on k={}".format(k))
            plt.savefig(rslt_dir + "/dynamics/samples_on_k{}.jpg".format(k), dpi=200)
            plt.close()

    if not os.path.exists(rslt_dir + "/distributions"):
        os.makedirs(rslt_dir + "/distributions")
        print("Making distributions directory...")

    # sanity checks
    plot_data_condition_on_all_zs(data, z, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_groundtruth.jpg", dpi=100)
    plot_data_condition_on_all_zs(sample_x, sample_z, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_sample_x.jpg", dpi=100)
    plot_data_condition_on_all_zs(sample_x_center, sample_z_center, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_sample_x_center.jpg", dpi=100)

    plot_2d_time_plot_condition_on_all_zs(data, z, K, title='ground truth')
    plt.savefig(rslt_dir + "/distributions/4traces_groundtruth.jpg", dpi=100)
    plot_2d_time_plot_condition_on_all_zs(sample_x, sample_z, K, title='sample_x')
    plt.savefig(rslt_dir + "/distributions/4traces_sample_x.jpg", dpi=100)
    plot_2d_time_plot_condition_on_all_zs(sample_x_center, sample_z_center, K, title='sample_x_center')
    plt.savefig(rslt_dir + "/distributions/4traces_sample_x_center.jpg", dpi=100)

    data_angles = get_all_angles(data, x_grids, y_grids, device=device)
    sample_angles = get_all_angles(sample_x, x_grids, y_grids, device=device)
    sample_x_center_angles = get_all_angles(sample_x_center, x_grids, y_grids, device=device)

    if animal == 'both':
        plot_list_of_angles([data_angles[0], sample_angles[0], sample_x_center_angles[0]],
                            ['data', 'sample', 'sample_c'], "direction distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_a.jpg")
        plt.close()
        plot_list_of_angles([data_angles[1], sample_angles[1], sample_x_center_angles[1]],
                            ['data', 'sample', 'sample_c'], "direction distribution (mother)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_b.jpg")
        plt.close()
    else:
        plot_list_of_angles([data_angles, sample_angles, sample_x_center_angles],
                            ['data', 'sample', 'sample_c'], "direction distribution ({})".format(animal), n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_{}.jpg".format(animal))
        plt.close()

    data_speed = get_speed(data, x_grids, y_grids, device=device)
    sample_speed = get_speed(sample_x, x_grids, y_grids, device=device)
    sample_x_center_speed = get_speed(sample_x_center, x_grids, y_grids, device=device)

    if animal == 'both':
        plot_list_of_speed([data_speed[0], sample_speed[0], sample_x_center_speed[0]],
                           ['data', 'sample', 'sample_c'], "speed distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_a.jpg")
        plt.close()
        plot_list_of_speed([data_speed[1], sample_speed[1], sample_x_center_speed[1]],
                           ['data', 'sample', 'sample_c'], "speed distribution (mother)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_b.jpg")
        plt.close()
    else:
        plot_list_of_speed([data_speed, sample_speed, sample_x_center_speed],
                           ['data', 'sample', 'sample_c'], "speed distribution ({})".format(animal), n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_{}.jpg".format(animal))
        plt.close()

    try:
        if 100 < data.shape[0] <= 36000:
            plot_space_dist(data, x_grids, y_grids)
        elif data.shape[0] > 36000:
            plot_space_dist(data[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_data.jpg")
        plt.close()

        if 100 < sample_x.shape[0] <= 36000:
            plot_space_dist(sample_x, x_grids, y_grids)
        elif sample_x.shape[0] > 36000:
            plot_space_dist(sample_x[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x.jpg")
        plt.close()

        if 100 < sample_x_center.shape[0] <= 36000:
            plot_space_dist(sample_x_center, x_grids, y_grids)
        elif sample_x_center.shape[0] > 36000:
            plot_space_dist(sample_x_center[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x_center.jpg")
        plt.close()
    except:
        print("plot_space_dist unsuccessful")
print(out)

##################### training ############################
num_iters = 10
losses, opt = model.fit(data, num_iters=num_iters, lr=0.001,
                        momentum_vecs=momentum_vecs, interaction_vecs=interaction_vecs)

##################### sampling ############################
print("start sampling")
sample_z, sample_x = model.sample(30)

#################### inference ###########################
print("inferring most likely states...")
z = model.most_likely_states(data, momentum_vecs=momentum_vecs, interaction_vecs=interaction_vecs)

print("k step prediction")
x_predict = k_step_prediction_for_momentum_interaction_model(
    model, z, data, momentum_vecs=momentum_vecs, interaction_vecs=interaction_vecs)

print("k step prediction without precomputed features.")
x_predict_2 = k_step_prediction(model, z, data, 10)
def rslt_saving(rslt_dir, model, data, animal, memory_kwargs, list_of_k_steps, sample_T, quiver_scale,
                x_grids=None, y_grids=None, valid_data=None, transition_memory_kwargs=None,
                valid_data_transition_memory_kwargs=None, valid_data_memory_kwargs=None,
                device=torch.device('cpu')):
    transition_memory_kwargs = transition_memory_kwargs if transition_memory_kwargs else {}
    valid_data_transition_memory_kwargs = \
        valid_data_transition_memory_kwargs if valid_data_transition_memory_kwargs else {}
    valid_data_memory_kwargs = valid_data_memory_kwargs if valid_data_memory_kwargs else {}
    memory_kwargs = memory_kwargs if memory_kwargs else {}

    obs = model.observation
    if animal == 'both':
        assert isinstance(obs, GPObservation), type(obs)
    else:
        assert isinstance(obs, GPObservationSingle), type(obs)

    if x_grids is None or y_grids is None:
        x_grids = obs.x_grids
        y_grids = obs.y_grids
    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1
    K = model.K

    #################### inference ###########################
    print("\ninferring most likely states...")
    z = model.most_likely_states(data, transition_mkwargs=transition_memory_kwargs, **memory_kwargs)
    z_valid = model.most_likely_states(valid_data, transition_mkwargs=valid_data_transition_memory_kwargs,
                                       **valid_data_memory_kwargs)  # TODO: address valid_data = None

    print("0 step prediction")
    # TODO: fix kwargs not matching error
    data_to_predict = data
    x_predict = k_step_prediction_for_gpmodel(model, z, data_to_predict, **memory_kwargs)
    x_predict_valid = k_step_prediction_for_gpmodel(model, z_valid, valid_data, **valid_data_memory_kwargs)

    x_predict_err = np.mean(np.abs(x_predict - get_np(data_to_predict)), axis=0)
    if len(valid_data) == 0:
        x_predict_valid_err = None
    else:
        x_predict_valid_err = np.mean(np.abs(x_predict_valid - get_np(valid_data)), axis=0)

    dict_of_x_predict_k = dict(x_predict_0=x_predict, x_predict_v_0=x_predict_valid)
    dict_of_x_predict_k_err = dict(x_predict_0_err=x_predict_err, x_predict_v_0_err=x_predict_valid_err)

    for k_step in list_of_k_steps:
        print("{} step prediction".format(k_step))
        x_predict_k = k_step_prediction(model, z, data_to_predict, k=k_step)
        x_predict_valid_k = k_step_prediction(model, z_valid, valid_data, k=k_step)

        x_predict_k_err = np.mean(np.abs(x_predict_k - get_np(data_to_predict[k_step:])), axis=0)
        if len(valid_data) == 0:
            x_predict_valid_k_err = None
        else:
            x_predict_valid_k_err = np.mean(np.abs(x_predict_valid_k - get_np(valid_data[k_step:])), axis=0)

        dict_of_x_predict_k["x_predict_{}".format(k_step)] = x_predict_k
        dict_of_x_predict_k["x_predict_v_{}".format(k_step)] = x_predict_valid_k
        dict_of_x_predict_k_err["x_predict_{}_err".format(k_step)] = x_predict_k_err
        dict_of_x_predict_k_err["x_predict_v_{}_err".format(k_step)] = x_predict_valid_k_err

    ################### samples #########################
    print("sampling")
    center_z = torch.tensor([0], dtype=torch.int, device=device)
    if animal == 'both':
        center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64, device=device)
    else:
        center_x = torch.tensor([[150, 190]], dtype=torch.float64, device=device)
    sample_z, sample_x = model.sample(sample_T)
    sample_z_center, sample_x_center = model.sample(sample_T, prefix=(center_z, center_x))

    ################## dynamics #####################
    print("dynamics")
    # quiver
    XX, YY = np.meshgrid(np.linspace(20, 310, 30), np.linspace(0, 380, 30))
    XY_grids = np.column_stack((np.ravel(XX), np.ravel(YY)))  # shape (900, 2) grid values
    if animal == 'both':
        XY_grids = np.concatenate((XY_grids, XY_grids), axis=1)  # (900, 4)

    # TODO: fix, not based on z
    if animal == 'both':
        XY_next_a, _ = obs.get_mu_and_cov_for_single_animal(XY_grids[:, 0:2], 0, mu_only=True)
        XY_next_b, _ = obs.get_mu_and_cov_for_single_animal(XY_grids[:, 2:4], 1, mu_only=True)
        XY_next = torch.cat((XY_next_a, XY_next_b), dim=-1)
        dXY = get_np(XY_next) - XY_grids[:, None]
    else:
        XY_next, _ = obs.get_mu_and_cov_for_single_animal(XY_grids, mu_only=True)
        dXY = get_np(XY_next) - XY_grids[:, None]

    #################### saving ##############################
    print("begin saving...")

    # save summary
    avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = np.average(np.abs(np.diff(sample_x, axis=0)), axis=0)
    avg_sample_center_speed = np.average(np.abs(np.diff(sample_x_center, axis=0)), axis=0)
    avg_data_speed = np.average(np.abs(np.diff(get_np(data), axis=0)), axis=0)

    if isinstance(model.transition, StationaryTransition):
        transition_matrix = model.transition.stationary_transition_matrix
    elif isinstance(model.transition, GridTransition):
        transition_matrix = model.transition.grid_transition_matrix
    else:
        raise ValueError("unsupported transition matrix type: {}".format(type(model.transition)))
    transition_matrix = get_np(transition_matrix)

    summary_dict = {"init_dist": get_np(model.init_dist),
                    "transition_matrix": transition_matrix,
                    "variance": get_np(torch.exp(model.observation.log_sigmas)),
                    "log_likes": get_np(model.log_likelihood(data, **memory_kwargs)),
                    "avg_data_speed": avg_data_speed,
                    "avg_sample_speed": avg_sample_speed,
                    "avg_sample_center_speed": avg_sample_center_speed}
    summary_dict = {**dict_of_x_predict_k_err, **summary_dict}
    if len(valid_data) > 0:
        summary_dict["valid_log_likes"] = get_np(model.log_likelihood(valid_data, **valid_data_memory_kwargs))
    summary_dict["rs"] = get_np(model.observation.rs)
    summary_dict["avg_transform_speed"] = avg_transform_speed

    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = {"z": z, "z_valid": z_valid,
                   "sample_z": sample_z, "sample_x": sample_x,
                   "sample_z_center": sample_z_center, "sample_x_center": sample_x_center}
    saving_dict = {**dict_of_x_predict_k, **saving_dict}
    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    if model.D == 2 and isinstance(model.transition, GridTransition):
        plot_grid_transition(n_x, n_y, model.transition.grid_transition_matrix)
        plt.savefig(rslt_dir + "/grid_transition.jpg")
        plt.close()

    if not os.path.exists(rslt_dir + "/samples"):
        os.makedirs(rslt_dir + "/samples")
        print("Making samples directory...")

    if K > 1:
        plot_z(z, K, title="most likely z for the ground truth")
        plt.savefig(rslt_dir + "/z.jpg")
        plt.close()

        if len(valid_data) > 0:
            plot_z(z_valid, K, title="most likely z for valid data")
            plt.savefig(rslt_dir + "/z_valid.jpg")
            plt.close()

        plot_z(sample_z, K, title="sample")
        plt.savefig(rslt_dir + "/samples/sample_z_{}.jpg".format(sample_T))
        plt.close()

        plot_z(sample_z_center, K, title="sample (starting from center)")
        plt.savefig(rslt_dir + "/samples/sample_z_center_{}.jpg".format(sample_T))
        plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(data, title="ground truth (training)",
               xlim=[ARENA_XMIN - 20, ARENA_YMAX + 20], ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/ground_truth.jpg")
    plt.close()

    if len(valid_data) > 0:
        plt.figure(figsize=(4, 4))
        plot_mouse(valid_data, title="ground truth (valid)",
                   xlim=[ARENA_XMIN - 20, ARENA_YMAX + 20], ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
        plt.legend()
        plt.savefig(rslt_dir + "/samples/ground_truth_valid.jpg")
        plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x, title="sample",
               xlim=[ARENA_XMIN - 20, ARENA_YMAX + 20], ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x_center, title="sample (starting from center)",
               xlim=[ARENA_XMIN - 20, ARENA_YMAX + 20], ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_center_{}.jpg".format(sample_T))
    plt.close()

    if K > 1:
        plot_realdata_quiver(data, z, K, x_grids, y_grids, title="ground truth (training)")
        plt.savefig(rslt_dir + "/samples/quiver_ground_truth.jpg", dpi=200)
        plt.close()

        if len(valid_data) > 0:
            plot_realdata_quiver(valid_data, z_valid, K, x_grids, y_grids, title="ground truth (valid)")
            plt.savefig(rslt_dir + "/samples/quiver_ground_truth_valid.jpg", dpi=200)
            plt.close()

        plot_realdata_quiver(sample_x, sample_z, K, x_grids, y_grids, title="sample")
        plt.savefig(rslt_dir + "/samples/quiver_sample_x_{}.jpg".format(sample_T), dpi=200)
        plt.close()

        plot_realdata_quiver(sample_x_center, sample_z_center, K, x_grids, y_grids,
                             title="sample (starting from center)")
        plt.savefig(rslt_dir + "/samples/quiver_sample_x_center_{}.jpg".format(sample_T), dpi=200)
        plt.close()

    if not os.path.exists(rslt_dir + "/dynamics"):
        os.makedirs(rslt_dir + "/dynamics")
        print("Making dynamics directory...")

    if animal == 'both':
        plot_quiver(XY_grids[:, 0:2], dXY[..., 0:2], 'virgin', K=K, scale=quiver_scale, alpha=0.9,
                    title="quiver (virgin)", x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_a.jpg", dpi=200)
        plt.close()

        plot_quiver(XY_grids[:, 2:4], dXY[..., 2:4], 'mother', K=K, scale=quiver_scale, alpha=0.9,
                    title="quiver (mother)", x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_b.jpg", dpi=200)
        plt.close()
    else:
        plot_quiver(XY_grids, dXY, animal, K=K, scale=quiver_scale, alpha=0.9,
                    title="quiver ({})".format(animal), x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_{}.jpg".format(animal), dpi=200)
        plt.close()

    if not os.path.exists(rslt_dir + "/distributions"):
        os.makedirs(rslt_dir + "/distributions")
        print("Making distributions directory...")

    # sanity checks
    plot_data_condition_on_all_zs(data, z, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_groundtruth.jpg", dpi=100)
    plot_data_condition_on_all_zs(sample_x, sample_z, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_sample_x.jpg", dpi=100)
    plot_data_condition_on_all_zs(sample_x_center, sample_z_center, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_sample_x_center.jpg", dpi=100)

    plot_2d_time_plot_condition_on_all_zs(data, z, K, title='ground truth')
    plt.savefig(rslt_dir + "/distributions/4traces_groundtruth.jpg", dpi=100)
    plot_2d_time_plot_condition_on_all_zs(sample_x, sample_z, K, title='sample_x')
    plt.savefig(rslt_dir + "/distributions/4traces_sample_x.jpg", dpi=100)
    plot_2d_time_plot_condition_on_all_zs(sample_x_center, sample_z_center, K, title='sample_x_center')
    plt.savefig(rslt_dir + "/distributions/4traces_sample_x_center.jpg", dpi=100)

    data_angles = get_all_angles(data, x_grids, y_grids, device=device)
    sample_angles = get_all_angles(sample_x, x_grids, y_grids, device=device)
    sample_x_center_angles = get_all_angles(sample_x_center, x_grids, y_grids, device=device)

    if animal == 'both':
        plot_list_of_angles([data_angles[0], sample_angles[0], sample_x_center_angles[0]],
                            ['data', 'sample', 'sample_c'], "direction distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_a.jpg")
        plt.close()
        plot_list_of_angles([data_angles[1], sample_angles[1], sample_x_center_angles[1]],
                            ['data', 'sample', 'sample_c'], "direction distribution (mother)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_b.jpg")
        plt.close()
    else:
        plot_list_of_angles([data_angles, sample_angles, sample_x_center_angles],
                            ['data', 'sample', 'sample_c'], "direction distribution ({})".format(animal), n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_{}.jpg".format(animal))
        plt.close()

    data_speed = get_speed(data, x_grids, y_grids, device=device)
    sample_speed = get_speed(sample_x, x_grids, y_grids, device=device)
    sample_x_center_speed = get_speed(sample_x_center, x_grids, y_grids, device=device)

    if animal == 'both':
        plot_list_of_speed([data_speed[0], sample_speed[0], sample_x_center_speed[0]],
                           ['data', 'sample', 'sample_c'], "speed distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_a.jpg")
        plt.close()
        plot_list_of_speed([data_speed[1], sample_speed[1], sample_x_center_speed[1]],
                           ['data', 'sample', 'sample_c'], "speed distribution (mother)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_b.jpg")
        plt.close()
    else:
        plot_list_of_speed([data_speed, sample_speed, sample_x_center_speed],
                           ['data', 'sample', 'sample_c'], "speed distribution ({})".format(animal), n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_{}.jpg".format(animal))
        plt.close()

    try:
        if 100 < data.shape[0] <= 36000:
            plot_space_dist(data, x_grids, y_grids)
        elif data.shape[0] > 36000:
            plot_space_dist(data[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_data.jpg")
        plt.close()

        if 100 < sample_x.shape[0] <= 36000:
            plot_space_dist(sample_x, x_grids, y_grids)
        elif sample_x.shape[0] > 36000:
            plot_space_dist(sample_x[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x.jpg")
        plt.close()

        if 100 < sample_x_center.shape[0] <= 36000:
            plot_space_dist(sample_x_center, x_grids, y_grids)
        elif sample_x_center.shape[0] > 36000:
            plot_space_dist(sample_x_center[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x_center.jpg")
        plt.close()
    except:
        print("plot_space_dist unsuccessful")
def rslt_saving(rslt_dir, model, data, mouse, sample_T, train_model, losses, quiver_scale, x_grids, y_grids):
    tran = model.observation.transformation
    _, D = data.shape
    assert D == 2 or D == 4, "D must be either 2 or 4."
    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1
    K = model.K

    #################### inference ###########################
    print("\ninferring most likely states...")
    z = model.most_likely_states(data)

    print("0 step prediction")
    if data.shape[0] <= 5000:
        data_to_predict = data
    else:
        data_to_predict = data[-5000:]
    x_predict = k_step_prediction(model, z, data_to_predict)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()), axis=0)

    print("5 step prediction")
    x_predict_5 = k_step_prediction(model, z, data_to_predict, k=5)
    x_predict_5_err = np.mean(np.abs(x_predict_5 - data_to_predict[5:].numpy()), axis=0)

    ################### samples #########################
    sample_z, sample_x = model.sample(sample_T)

    center_z = torch.tensor([0], dtype=torch.int)
    if D == 4:
        center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64)
    else:
        center_x = torch.tensor([[150, 190]], dtype=torch.float64)
    sample_z_center, sample_x_center = model.sample(sample_T, prefix=(center_z, center_x))

    ################## dynamics #####################
    # quiver
    XX, YY = np.meshgrid(np.linspace(20, 310, 30), np.linspace(0, 380, 30))
    XY = np.column_stack((np.ravel(XX), np.ravel(YY)))  # shape (900, 2) grid values
    if D == 2:
        XY_grids = XY
    else:
        XY_grids = np.concatenate((XY, XY), axis=1)

    XY_next = tran.transform(torch.tensor(XY_grids, dtype=torch.float64))
    dXY = XY_next.detach().numpy() - XY_grids[:, None]

    #################### saving ##############################
    print("begin saving...")

    # save summary
    avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = np.average(np.abs(np.diff(sample_x, axis=0)), axis=0)
    avg_sample_center_speed = np.average(np.abs(np.diff(sample_x_center, axis=0)), axis=0)
    avg_data_speed = np.average(np.abs(np.diff(data.numpy(), axis=0)), axis=0)

    transition_matrix = model.transition.stationary_transition_matrix
    if transition_matrix.requires_grad:
        transition_matrix = transition_matrix.detach().numpy()
    else:
        transition_matrix = transition_matrix.numpy()

    cluster_centers = get_np(tran.mus_loc)

    summary_dict = {
        "init_dist": model.init_dist.detach().numpy(),
        "transition_matrix": transition_matrix,
        "x_predict_err": x_predict_err,
        "x_predict_5_err": x_predict_5_err,
        "mus": cluster_centers,
        "variance": torch.exp(model.observation.log_sigmas).detach().numpy(),
        "log_likes": model.log_likelihood(data).detach().numpy(),
        "avg_transform_speed": avg_transform_speed,
        "avg_data_speed": avg_data_speed,
        "avg_sample_speed": avg_sample_speed,
        "avg_sample_center_speed": avg_sample_center_speed
    }
    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = {
        "z": z,
        "x_predict": x_predict,
        "x_predict_5": x_predict_5,
        "sample_z": sample_z,
        "sample_x": sample_x,
        "sample_z_center": sample_z_center,
        "sample_x_center": sample_x_center
    }
    if train_model:
        saving_dict['losses'] = losses
        plt.figure()
        plt.plot(losses)
        plt.savefig(rslt_dir + "/losses.jpg")
        plt.close()
    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    plot_z(z, K, title="most likely z for the ground truth")
    plt.savefig(rslt_dir + "/z.jpg")
    plt.close()

    if not os.path.exists(rslt_dir + "/samples"):
        os.makedirs(rslt_dir + "/samples")
        print("Making samples directory...")

    plot_z(sample_z, K, title="sample")
    plt.savefig(rslt_dir + "/samples/sample_z_{}.jpg".format(sample_T))
    plt.close()

    plot_z(sample_z_center, K, title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/sample_z_center_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(data, title="ground truth_{}".format(mouse),
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20], ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/ground_truth.jpg")
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x, title="sample_{}".format(mouse),
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20], ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x_center, title="sample (starting from center)_{}".format(mouse),
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20], ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_center_{}.jpg".format(sample_T))
    plt.close()

    plot_realdata_quiver(data, z, K, x_grids, y_grids, title="ground truth", cluster_centers=cluster_centers)
    plt.savefig(rslt_dir + "/samples/quiver_ground_truth.jpg", dpi=200)
    plt.close()

    plot_realdata_quiver(sample_x, sample_z, K, x_grids, y_grids, title="sample", cluster_centers=cluster_centers)
    plt.savefig(rslt_dir + "/samples/quiver_sample_x_{}.jpg".format(sample_T), dpi=200)
    plt.close()

    plot_realdata_quiver(sample_x_center, sample_z_center, K, x_grids, y_grids,
                         title="sample (starting from center)", cluster_centers=cluster_centers)
    plt.savefig(rslt_dir + "/samples/quiver_sample_x_center_{}.jpg".format(sample_T), dpi=200)
    plt.close()

    # plot mus
    plot_cluster_centers(cluster_centers, x_grids, y_grids)
    plt.savefig(rslt_dir + "/samples/cluster_centers.jpg", dpi=200)
    plt.close()

    if not os.path.exists(rslt_dir + "/dynamics"):
        os.makedirs(rslt_dir + "/dynamics")
        print("Making dynamics directory...")

    if D == 2:
        plot_quiver(XY_grids, dXY, mouse, K=K, scale=quiver_scale, alpha=0.9,
                    title="quiver ({})".format(mouse), x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_{}.jpg".format(mouse), dpi=200)
        plt.close()
    else:
        plot_quiver(XY_grids[:, 0:2], dXY[..., 0:2], 'virgin', K=K, scale=quiver_scale, alpha=0.9,
                    title="quiver (virgin)", x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_a.jpg", dpi=200)
        plt.close()

        plot_quiver(XY_grids[:, 2:4], dXY[..., 2:4], 'mother', K=K, scale=quiver_scale, alpha=0.9,
                    title="quiver (mother)", x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_b.jpg", dpi=200)
        plt.close()

    if not os.path.exists(rslt_dir + "/distributions"):
        os.makedirs(rslt_dir + "/distributions")
        print("Making distributions directory...")

    if D == 4:
        data_angles_a, data_angles_b = get_all_angles(data, x_grids, y_grids)
        sample_angles_a, sample_angles_b = get_all_angles(sample_x, x_grids, y_grids)
        sample_x_center_angles_a, sample_x_center_angles_b = get_all_angles(sample_x_center, x_grids, y_grids)

        plot_list_of_angles([data_angles_a, sample_angles_a, sample_x_center_angles_a],
                            ['data', 'sample', 'sample_c'], "direction distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_a.jpg")
        plt.close()
        plot_list_of_angles([data_angles_b, sample_angles_b, sample_x_center_angles_b],
                            ['data', 'sample', 'sample_c'], "direction distribution (mother)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_b.jpg")
        plt.close()

        data_speed_a, data_speed_b = get_speed(data, x_grids, y_grids)
        sample_speed_a, sample_speed_b = get_speed(sample_x, x_grids, y_grids)
        sample_x_center_speed_a, sample_x_center_speed_b = get_speed(sample_x_center, x_grids, y_grids)

        plot_list_of_speed([data_speed_a, sample_speed_a, sample_x_center_speed_a],
                           ['data', 'sample', 'sample_c'], "speed distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_a.jpg")
        plt.close()
        plot_list_of_speed([data_speed_b, sample_speed_b, sample_x_center_speed_b],
                           ['data', 'sample', 'sample_c'], "speed distribution (mother)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_b.jpg")
        plt.close()
    else:
        data_angles_a = get_all_angles(data, x_grids, y_grids)
        sample_angles_a = get_all_angles(sample_x, x_grids, y_grids)
        sample_x_center_angles_a = get_all_angles(sample_x_center, x_grids, y_grids)

        plot_list_of_angles([data_angles_a, sample_angles_a, sample_x_center_angles_a],
                            ['data', 'sample', 'sample_c'], "direction distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_{}.jpg".format(mouse))
        plt.close()

        data_speed_a = get_speed(data, x_grids, y_grids)
        sample_speed_a = get_speed(sample_x, x_grids, y_grids)
        sample_x_center_speed_a = get_speed(sample_x_center, x_grids, y_grids)

        plot_list_of_speed([data_speed_a, sample_speed_a, sample_x_center_speed_a],
                           ['data', 'sample', 'sample_c'], "speed distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_{}.jpg".format(mouse))
        plt.close()

    try:
        if 100 < data.shape[0] <= 36000:
            plot_space_dist(data, x_grids, y_grids)
        elif data.shape[0] > 36000:
            plot_space_dist(data[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_data.jpg")
        plt.close()

        if 100 < sample_x.shape[0] <= 36000:
            plot_space_dist(sample_x, x_grids, y_grids)
        elif sample_x.shape[0] > 36000:
            plot_space_dist(sample_x[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x.jpg")
        plt.close()

        if 100 < sample_x_center.shape[0] <= 36000:
            plot_space_dist(sample_x_center, x_grids, y_grids)
        elif sample_x_center.shape[0] > 36000:
            plot_space_dist(sample_x_center[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x_center.jpg")
        plt.close()
    except:
        print("plot_space_dist unsuccessful")
def test_model():
    torch.manual_seed(0)
    np.random.seed(0)

    T = 100
    D = 4
    # data = np.array([[1.0, 1.0, 1.0, 6.0], [3.0, 6.0, 8.0, 6.0],
    #                  [4.0, 7.0, 8.0, 5.0], [6.0, 7.0, 5.0, 6.0], [8.0, 2.0, 6.0, 1.0]])
    data = np.random.randn(T, D)
    data = torch.tensor(data, dtype=torch.float64)

    xmax = max(np.max(data[:, 0].numpy()), np.max(data[:, 2].numpy()))
    xmin = min(np.min(data[:, 0].numpy()), np.min(data[:, 2].numpy()))
    ymax = max(np.max(data[:, 1].numpy()), np.max(data[:, 3].numpy()))
    ymin = min(np.min(data[:, 1].numpy()), np.min(data[:, 3].numpy()))
    bounds = np.array([[xmin - 1, xmax + 1], [ymin - 1, ymax + 1],
                       [xmin - 1, xmax + 1], [ymin - 1, ymax + 1]])

    def toy_feature_vec_func(s):
        """
        :param s: self, (T, 2)
        :return: features, (T, Df, 2)
        """
        corners = torch.tensor([[0, 0], [0, 8], [10, 0], [10, 8]], dtype=torch.float64)
        return feature_direction_vec(s, corners)

    K = 3
    Df = 4
    lags = 1

    tran = UniLSTMTransformation(K=K, D=D, Df=Df, feature_vec_func=toy_feature_vec_func, lags=lags, dh=10)

    # observation
    obs = ARTruncatedNormalObservation(K=K, D=D, lags=lags, bounds=bounds, transformation=tran)

    # model
    model = HMM(K=K, D=D, observation=obs)

    print("calculating log likelihood")
    feature_vecs_a = toy_feature_vec_func(data[:-1, 0:2])
    feature_vecs_b = toy_feature_vec_func(data[:-1, 2:4])
    feature_vecs = (feature_vecs_a, feature_vecs_b)
    packed_data = get_packed_data(data[:-1], lags=lags)

    model.log_likelihood(data, feature_vecs=feature_vecs, packed_data=packed_data)

    # fit
    losses, _ = model.fit(data, optimizer=None, method="adam", num_iters=50,
                          feature_vecs=feature_vecs, packed_data=packed_data)
    plt.figure()
    plt.plot(losses)
    plt.show()

    # most-likely-z
    print("Most likely z...")
    z = model.most_likely_states(data, feature_vecs=feature_vecs, packed_data=packed_data)

    # prediction
    print("0 step prediction")
    if data.shape[0] <= 1000:
        data_to_predict = data
    else:
        data_to_predict = data[-1000:]
    x_predict = k_step_prediction_for_lstm_model(model, z, data_to_predict, feature_vecs=feature_vecs)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()), axis=0)

    print("10 step prediction")
    x_predict_2 = k_step_prediction(model, z, data_to_predict, k=10)
    x_predict_2_err = np.mean(np.abs(x_predict_2 - data_to_predict[10:].numpy()), axis=0)

    # samples
    print("sampling...")
    sample_T = 5
    sample_z, sample_x = model.sample(sample_T)
def k_step_prediction_for_lstm_based_model(model, model_z, data, k=0, feature_vecs=None):
    data = check_and_convert_to_tensor(data)
    T, D = data.shape

    lstm_states = {}
    x_predict_arr = []
    if k == 0:
        if feature_vecs is None:
            print("Did not provide memory information")
            return k_step_prediction(model, model_z, data)
        else:
            feature_vecs_a, feature_vecs_b = feature_vecs

            x_predict = model.observation.sample_x(model_z[0], data[:0], return_np=True)
            x_predict_arr.append(x_predict)
            for t in range(1, data.shape[0]):
                feature_vec_t = (feature_vecs_a[t - 1:t], feature_vecs_b[t - 1:t])
                x_predict = model.observation.sample_x(
                    model_z[t], data[:t], return_np=True, with_noise=True,
                    feature_vec=feature_vec_t, lstm_states=lstm_states)
                x_predict_arr.append(x_predict)
    else:
        assert k > 0
        # neglects t = 0 since there is no history
        if T <= k:
            raise ValueError("Please input k such that k < {}.".format(T))
        for t in range(1, T - k + 1):
            # sample k steps forward
            # first step uses the real value
            z, x = model.sample(1, prefix=(model_z[t - 1:t], data[t - 1:t]), return_np=False, with_noise=True,
                                lstm_states=lstm_states)
            # the last k-1 steps use sampled values
            if k >= 1:
                sampled_lstm_states = dict(h_t=lstm_states["h_t"], c_t=lstm_states["c_t"])
                for i in range(k - 1):
                    z, x = model.sample(1, prefix=(z, x), return_np=False, with_noise=True,
                                        lstm_states=sampled_lstm_states)
            assert x.shape == (1, D)
            x_predict_arr.append(get_np(x[0]))

    x_predict_arr = np.array(x_predict_arr)
    assert x_predict_arr.shape == (T - k, D)
    return x_predict_arr
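# Illustrative usage (a minimal sketch, not part of the original module): k=0 reuses the precomputed feature
# vectors for one-step-ahead sampling, while k>0 rolls the model forward from each time step, e.g.
#   x_predict_0 = k_step_prediction_for_lstm_based_model(model, z, data, k=0, feature_vecs=feature_vecs)
#   x_predict_5 = k_step_prediction_for_lstm_based_model(model, z, data, k=5)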
def test_model():
    torch.manual_seed(0)
    np.random.seed(0)

    T = 5
    x_grids = np.array([0.0, 5.0, 10.0])
    y_grids = np.array([0.0, 4.0, 8.0])

    data = np.array([[1.0, 1.0, 1.0, 6.0], [3.0, 6.0, 8.0, 6.0], [4.0, 7.0, 8.0, 5.0],
                     [6.0, 7.0, 5.0, 6.0], [8.0, 2.0, 6.0, 1.0]])
    data = torch.tensor(data, dtype=torch.float64)

    def toy_feature_vec_func(s):
        """
        :param s: self, (T, 2)
        :return: features, (T, Df, 2)
        """
        corners = torch.tensor([[0, 0], [0, 8], [10, 0], [10, 8]], dtype=torch.float64)
        return feature_direction_vec(s, corners)

    K = 3
    D = 4
    M = 0
    Df = 4

    bounds = np.array([[0.0, 10.0], [0.0, 8.0], [0.0, 10.0], [0.0, 8.0]])

    tran = LinearGridTransformation(K=K, D=D, x_grids=x_grids, y_grids=y_grids, Df=Df,
                                    feature_vec_func=toy_feature_vec_func)
    obs = ARTruncatedNormalObservation(K=K, D=D, M=0, lags=1, bounds=bounds, transformation=tran)

    model = HMM(K=K, D=D, M=M, transition="stationary", observation=obs)
    model.observation.mus_init = data[0] * torch.ones(K, D, dtype=torch.float64)

    # calculate memory
    gridpoints_idx_a = tran.get_gridpoints_idx_for_batch(data[:-1, 0:2])
    gridpoints_idx_b = tran.get_gridpoints_idx_for_batch(data[:-1, 2:4])
    gridpoints_a = tran.get_gridpoints_for_batch(gridpoints_idx_a)
    gridpoints_b = tran.get_gridpoints_for_batch(gridpoints_idx_b)
    feature_vecs_a = toy_feature_vec_func(data[:-1, 0:2])
    feature_vecs_b = toy_feature_vec_func(data[:-1, 2:4])

    gridpoints_idx = (gridpoints_idx_a, gridpoints_idx_b)
    gridpoints = (gridpoints_a, gridpoints_b)
    feature_vecs = (feature_vecs_a, feature_vecs_b)

    # fit
    losses, opt = model.fit(data, optimizer=None, method='adam', num_iters=100, lr=0.01,
                            pbar_update_interval=10,
                            gridpoints=gridpoints, gridpoints_idx=gridpoints_idx, feature_vecs=feature_vecs)
    plt.figure()
    plt.plot(losses)
    plt.show()

    # most-likely-z
    print("Most likely z...")
    z = model.most_likely_states(data, gridpoints_idx=gridpoints_idx, feature_vecs=feature_vecs)

    # prediction
    print("0 step prediction")
    if data.shape[0] <= 1000:
        data_to_predict = data
    else:
        data_to_predict = data[-1000:]
    x_predict = k_step_prediction_for_lineargrid_model(
        model, z, data_to_predict,
        gridpoints=gridpoints, gridpoints_idx=gridpoints_idx, feature_vecs=feature_vecs)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()), axis=0)

    print("2 step prediction")
    x_predict_2 = k_step_prediction(model, z, data_to_predict, k=2)
    x_predict_2_err = np.mean(np.abs(x_predict_2 - data_to_predict[2:].numpy()), axis=0)

    # samples
    sample_T = 5
    sample_z, sample_x = model.sample(sample_T)
def test_model():
    torch.manual_seed(0)
    np.random.seed(0)

    T = 5
    x_grids = np.array([0.0, 10.0])
    y_grids = np.array([0.0, 8.0])
    bounds = np.array([[0.0, 10.0], [0.0, 8.0], [0.0, 10.0], [0.0, 8.0]])

    data = np.array([[1.0, 1.0, 1.0, 6.0], [3.0, 6.0, 8.0, 6.0], [4.0, 7.0, 8.0, 5.0],
                     [6.0, 7.0, 5.0, 6.0], [8.0, 2.0, 6.0, 1.0]])
    data = torch.tensor(data, dtype=torch.float64)

    K = 3
    D = 4
    M = 0

    obs = GPObservation(K=K, D=D, x_grids=x_grids, y_grids=y_grids, bounds=bounds, train_rs=True)

    correct_kerneldist_gg = torch.tensor(
        [[0., 0., 64., 64., 100., 100., 164., 164.],
         [0., 0., 64., 64., 100., 100., 164., 164.],
         [64., 64., 0., 0., 164., 164., 100., 100.],
         [64., 64., 0., 0., 164., 164., 100., 100.],
         [100., 100., 164., 164., 0., 0., 64., 64.],
         [100., 100., 164., 164., 0., 0., 64., 64.],
         [164., 164., 100., 100., 64., 64., 0., 0.],
         [164., 164., 100., 100., 64., 64., 0., 0.]], dtype=torch.float64)
    assert torch.all(torch.eq(correct_kerneldist_gg, obs.kernel_distsq_gg)), obs.kernel_distsq_gg

    log_prob_nocache = obs.log_prob(data)
    print("log_prob_nocache = {}".format(log_prob_nocache))

    kernel_distsq_xg_a = kernel_distsq_doubled(data[:-1, 0:2], obs.inducing_points)
    kernel_distsq_xg_b = kernel_distsq_doubled(data[:-1, 2:4], obs.inducing_points)

    correct_kernel_distsq_xg_a = torch.tensor(
        [[2., 2., 50., 50., 82., 82., 130., 130.],
         [2., 2., 50., 50., 82., 82., 130., 130.],
         [45., 45., 13., 13., 85., 85., 53., 53.],
         [45., 45., 13., 13., 85., 85., 53., 53.],
         [65., 65., 17., 17., 85., 85., 37., 37.],
         [65., 65., 17., 17., 85., 85., 37., 37.],
         [85., 85., 37., 37., 65., 65., 17., 17.],
         [85., 85., 37., 37., 65., 65., 17., 17.]], dtype=torch.float64)
    assert torch.all(torch.eq(correct_kernel_distsq_xg_a, kernel_distsq_xg_a)), kernel_distsq_xg_a

    memory_kwargs = dict(kernel_distsq_xg_a=kernel_distsq_xg_a, kernel_distsq_xg_b=kernel_distsq_xg_b)

    log_prob = obs.log_prob(data, **memory_kwargs)
    print("log_prob = {}".format(log_prob))
    assert torch.all(torch.eq(log_prob_nocache, log_prob))

    Sigma_a, A_a = obs.get_gp_cache(data[:-1, 0:2], 0, **memory_kwargs)
    Sigma_b, A_b = obs.get_gp_cache(data[:-1, 2:4], 1, **memory_kwargs)
    memory_kwargs_2 = dict(Sigma_a=Sigma_a, A_a=A_a, Sigma_b=Sigma_b, A_b=A_b)

    print("calculating log prob 2...")
    log_prob2 = obs.log_prob(data, **memory_kwargs_2)
    assert torch.all(torch.eq(log_prob, log_prob2))

    model = HMM(K=K, D=D, M=M, transition="stationary", observation=obs)
    model.observation.mus_init = data[0] * torch.ones(K, D, dtype=torch.float64)

    # fit
    losses, opt = model.fit(data, optimizer=None, method='adam', num_iters=100, lr=0.01,
                            pbar_update_interval=10, **memory_kwargs)
    plt.figure()
    plt.plot(losses)
    plt.show()

    # most-likely-z
    print("Most likely z...")
    z = model.most_likely_states(data, **memory_kwargs)

    # prediction
    print("0 step prediction")
    if data.shape[0] <= 1000:
        data_to_predict = data
    else:
        data_to_predict = data[-1000:]
    x_predict = k_step_prediction_for_gpmodel(model, z, data_to_predict, **memory_kwargs)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()), axis=0)

    print("2 step prediction")
    x_predict_2 = k_step_prediction(model, z, data_to_predict, k=2)
    x_predict_2_err = np.mean(np.abs(x_predict_2 - data_to_predict[2:].numpy()), axis=0)

    # samples
    sample_T = 5
    sample_z, sample_x = model.sample(sample_T)
joblib.dump(losses, "losses")

plt.plot(losses)

# sampling
print("start sampling")
sample_z, sample_x = model.sample(T)
plot_realdata_quiver(sample_x, scale=1, title="sample")
plt.show()
joblib.dump((sample_z, sample_x), "samples")

# inference
print("inferring most likely states...")
z = model.most_likely_states(data, memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)
joblib.dump(z, "z")

data_to_predict = data[-1000:]

print("0 step prediction")
x_predict = k_step_prediction_for_grid_model(model, z, data_to_predict,
                                             memory_kwargs_a=m_kwargs_a,
                                             memory_kwargs_b=m_kwargs_b)
err = np.mean(np.abs(x_predict - data_to_predict.numpy()), axis=0)
print(err)

print("5 step prediction")
x_predict_5 = k_step_prediction(model, z, data_to_predict, 5)
err = np.mean(np.abs(x_predict_5 - data_to_predict[5:].numpy()), axis=0)
print(err)
# model
model = HMM(K=K, D=D, M=M, observation=obs)

# log likelihood
log_prob = model.log_probability(data)
print("log probability = ", log_prob)

# training
print("start training...")
num_iters = 10
losses, opt = model.fit(data, num_iters=num_iters, lr=0.001)

# sampling
sample_T = T
print("start sampling")
sample_z, sample_x = model.sample(sample_T)

print("start sampling with noise")
sample_z2, sample_x2 = model.sample(T, with_noise=True)

# inference
print("inferring most likely states...")
z = model.most_likely_states(data)

print("0 step prediction")
x_predict = k_step_prediction(model, z, data)

print("2 step prediction")
x_predict_2 = k_step_prediction(model, z, data, 2)
losses = []
for i in np.arange(num_iters):
    optimizer.zero_grad()
    loss = model.loss(data)
    loss.backward(retain_graph=True)
    optimizer.step()

    loss = loss.detach().numpy()
    losses.append(loss)
    if i % 10 == 0:
        pbar.set_description('iter {} loss {:.2f}'.format(i, loss))
        pbar.update(10)

# check reconstruction
x_reconstruct = model.sample_condition_on_zs(z, data[0])

# infer the latent states
infer_z = model.most_likely_states(data)
perm = find_permutation(z.numpy(), infer_z, K1=K, K2=K)
model.permute(perm)
hmm_z = model.most_likely_states(data)

# check prediction
x_predict_cond_z = k_step_prediction(model, z, data)
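# find_permutation is assumed to relabel the inferred states so they best match
# the ground-truth labels. A minimal sketch of that idea (not necessarily the
# library's implementation), using the Hungarian algorithm on overlap counts:
from scipy.optimize import linear_sum_assignment

def find_permutation_sketch(z_true, z_inferred, K1, K2):
    overlap = np.zeros((K1, K2))
    for a, b in zip(z_true, z_inferred):
        overlap[a, b] += 1
    # maximizing total overlap is equivalent to minimizing its negation
    _, perm = linear_sum_assignment(-overlap)
    return perm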
lags = 1
bounds = np.array([[-2, 2], [0, 1], [-2, 2], [0, 1]])
As = np.array([
    np.column_stack([np.identity(D), np.zeros((D, (lags - 1) * D))])
    for _ in range(K)
])

torch.manual_seed(0)
np.random.seed(0)

tran = LinearTransformation(K=K, D=D, lags=lags, As=As)
observation = ARTruncatedNormalObservation(K=K, D=D, M=0, transformation=tran, bounds=bounds)
model = HMM(K=K, D=D, M=0, observation=observation)

lls = model.log_likelihood(data)
print(lls)

# losses_1, optimizer_1 = model_1.fit(data_1, method='adam', num_iters=2000, lr=0.001)

z_1 = model.most_likely_states(data)
x_predict_arr_lag1 = k_step_prediction(model, z_1, data)
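# For illustration only (an assumed extension of the construction above, not
# used by the experiment): with lags = 2 the same column_stack pads the identity
# with a (D, D) zero block, so the AR mean still depends only on the most recent frame.
lags_2 = 2
As_lag2 = np.array([
    np.column_stack([np.identity(D), np.zeros((D, (lags_2 - 1) * D))])
    for _ in range(K)
])  # each entry has shape (D, lags_2 * D)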
def rslt_saving(rslt_dir, model, Df, data, masks_a, masks_b, m_kwargs_a, m_kwargs_b,
                sample_T, train_model, losses, quiver_scale):
    tran = model.observation.transformation
    x_grids = tran.x_grids
    y_grids = tran.y_grids
    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1
    G = n_x * n_y
    f_corner_vec_func = tran.transformations_a[0].feature_vec_func
    K = model.K

    #################### inference ###########################
    print("\ninferring most likely states...")
    z = model.most_likely_states(data, masks=(masks_a, masks_b),
                                 memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)

    print("0 step prediction")
    if data.shape[0] <= 1000:
        data_to_predict = data
    else:
        data_to_predict = data[-1000:]
    x_predict = k_step_prediction_for_grid_model(model, z, data_to_predict,
                                                 memory_kwargs_a=m_kwargs_a,
                                                 memory_kwargs_b=m_kwargs_b)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()), axis=0)

    print("5 step prediction")
    x_predict_5 = k_step_prediction(model, z, data_to_predict, k=5)
    x_predict_5_err = np.mean(np.abs(x_predict_5 - data_to_predict[5:].numpy()), axis=0)

    ################### samples #########################
    sample_z, sample_x = model.sample(sample_T)

    center_z = torch.tensor([0], dtype=torch.int)
    center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64)
    sample_z_center, sample_x_center = model.sample(sample_T, prefix=(center_z, center_x))

    ################## dynamics #####################
    # weights
    weights_a = np.array([t.weights.detach().numpy() for t in tran.transformations_a])
    weights_b = np.array([t.weights.detach().numpy() for t in tran.transformations_b])

    # dynamics
    grid_centers = np.array([[1 / 2 * (x_grids[i] + x_grids[i + 1]),
                              1 / 2 * (y_grids[j] + y_grids[j + 1])]
                             for i in range(n_x) for j in range(n_y)])
    unit_corner_vecs = f_corner_vec_func(torch.tensor(grid_centers, dtype=torch.float64))
    unit_corner_vecs = unit_corner_vecs.numpy()
    # (G, 1, Df, d) * (G, K, Df, 1) --> (G, K, Df, d)
    weighted_corner_vecs_a = unit_corner_vecs[:, None] * weights_a[..., None]
    weighted_corner_vecs_b = unit_corner_vecs[:, None] * weights_b[..., None]

    grid_z_a_percentage = get_z_percentage_by_grid(masks_a, z, K, G)
    grid_z_b_percentage = get_z_percentage_by_grid(masks_b, z, K, G)

    # quiver
    XX, YY = np.meshgrid(np.linspace(20, 310, 30), np.linspace(0, 380, 30))
    XY = np.column_stack((np.ravel(XX), np.ravel(YY)))  # shape (900, 2) grid values
    XY_grids = np.concatenate((XY, XY), axis=1)
    XY_next = tran.transform(torch.tensor(XY_grids, dtype=torch.float64))
    dXY = XY_next.detach().numpy() - XY_grids[:, None]

    #################### saving ##############################
    print("begin saving...")

    # save summary
    avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = np.average(np.abs(np.diff(sample_x, axis=0)), axis=0)
    avg_sample_center_speed = np.average(np.abs(np.diff(sample_x_center, axis=0)), axis=0)
    avg_data_speed = np.average(np.abs(np.diff(data.numpy(), axis=0)), axis=0)

    transition_matrix = model.transition.stationary_transition_matrix
    if transition_matrix.requires_grad:
        transition_matrix = transition_matrix.detach().numpy()
    else:
        transition_matrix = transition_matrix.numpy()

    summary_dict = {
        "init_dist": model.init_dist.detach().numpy(),
        "transition_matrix": transition_matrix,
        "x_predict_err": x_predict_err,
        "x_predict_5_err": x_predict_5_err,
        "variance": torch.exp(model.observation.log_sigmas).detach().numpy(),
        "log_likes": model.log_likelihood(data).detach().numpy(),
        "grid_z_a_percentage": grid_z_a_percentage,
        "grid_z_b_percentage": grid_z_b_percentage,
        "avg_transform_speed": avg_transform_speed,
        "avg_data_speed": avg_data_speed,
        "avg_sample_speed": avg_sample_speed,
        "avg_sample_center_speed": avg_sample_center_speed
    }
    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = {
        "z": z,
        "x_predict": x_predict,
        "x_predict_5": x_predict_5,
        "sample_z": sample_z,
        "sample_x": sample_x,
        "sample_z_center": sample_z_center,
        "sample_x_center": sample_x_center
    }
    if train_model:
        saving_dict['losses'] = losses
        plt.figure()
        plt.plot(losses)
        plt.savefig(rslt_dir + "/losses.jpg")
        plt.close()
    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    plot_z(z, K, title="most likely z for the ground truth")
    plt.savefig(rslt_dir + "/z.jpg")
    plt.close()

    if not os.path.exists(rslt_dir + "/samples"):
        os.makedirs(rslt_dir + "/samples")
        print("Making samples directory...")

    plot_z(sample_z, K, title="sample")
    plt.savefig(rslt_dir + "/samples/sample_z_{}.jpg".format(sample_T))
    plt.close()

    plot_z(sample_z_center, K, title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/sample_z_center_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(data, title="ground truth",
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/ground_truth.jpg")
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x, title="sample",
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x_center, title="sample (starting from center)",
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_center_{}.jpg".format(sample_T))
    plt.close()

    plot_realdata_quiver(data, z, K, x_grids, y_grids, title="ground truth")
    plt.savefig(rslt_dir + "/samples/ground_truth_quiver.jpg")
    plt.close()

    plot_realdata_quiver(sample_x, sample_z, K, x_grids, y_grids, title="sample")
    plt.savefig(rslt_dir + "/samples/sample_x_quiver_{}.jpg".format(sample_T))
    plt.close()

    plot_realdata_quiver(sample_x_center, sample_z_center, K, x_grids, y_grids,
                         title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/sample_x_center_quiver_{}.jpg".format(sample_T))
    plt.close()

    if not os.path.exists(rslt_dir + "/dynamics"):
        os.makedirs(rslt_dir + "/dynamics")
        print("Making dynamics directory...")

    plot_weights(weights_a, Df, K, x_grids, y_grids,
                 max_weight=tran.transformations_a[0].acc_factor, title="weights (virgin)")
    plt.savefig(rslt_dir + "/dynamics/weights_a.jpg")
    plt.close()

    plot_weights(weights_b, Df, K, x_grids, y_grids,
                 max_weight=tran.transformations_b[0].acc_factor, title="weights (mother)")
    plt.savefig(rslt_dir + "/dynamics/weights_b.jpg")
    plt.close()

    plot_dynamics(weighted_corner_vecs_a, "virgin", x_grids, y_grids, K=K,
                  scale=quiver_scale, percentage=grid_z_a_percentage,
                  title="grid dynamics (virgin)")
    plt.savefig(rslt_dir + "/dynamics/dynamics_a.jpg")
    plt.close()

    plot_dynamics(weighted_corner_vecs_b, "mother", x_grids, y_grids, K=K,
                  scale=quiver_scale, percentage=grid_z_b_percentage,
                  title="grid dynamics (mother)")
    plt.savefig(rslt_dir + "/dynamics/dynamics_b.jpg")
    plt.close()

    plot_quiver(XY_grids[:, 0:2], dXY[..., 0:2], 'virgin', K=K,
                scale=quiver_scale, alpha=0.9, title="quiver (virgin)")
    plt.savefig(rslt_dir + "/dynamics/quiver_a.jpg")
    plt.close()

    plot_quiver(XY_grids[:, 2:4], dXY[..., 2:4], 'mother', K=K,
                scale=quiver_scale, alpha=0.9, title="quiver (mother)")
    plt.savefig(rslt_dir + "/dynamics/quiver_b.jpg")
    plt.close()

    if not os.path.exists(rslt_dir + "/distributions"):
        os.makedirs(rslt_dir + "/distributions")
        print("Making distributions directory...")

    data_angles_a, data_angles_b = get_all_angles(data, x_grids, y_grids)
    sample_angles_a, sample_angles_b = get_all_angles(sample_x, x_grids, y_grids)
    sample_x_center_angles_a, sample_x_center_angles_b = get_all_angles(sample_x_center, x_grids, y_grids)

    plot_list_of_angles([data_angles_a, sample_angles_a, sample_x_center_angles_a],
                        ['data', 'sample', 'sample_c'],
                        "direction distribution (virgin)", n_x, n_y)
    plt.savefig(rslt_dir + "/distributions/angles_a.jpg")
    plt.close()

    plot_list_of_angles([data_angles_b, sample_angles_b, sample_x_center_angles_b],
                        ['data', 'sample', 'sample_c'],
                        "direction distribution (mother)", n_x, n_y)
    plt.savefig(rslt_dir + "/distributions/angles_b.jpg")
    plt.close()

    data_speed_a, data_speed_b = get_speed(data, x_grids, y_grids)
    sample_speed_a, sample_speed_b = get_speed(sample_x, x_grids, y_grids)
    sample_x_center_speed_a, sample_x_center_speed_b = get_speed(sample_x_center, x_grids, y_grids)

    plot_list_of_speed([data_speed_a, sample_speed_a, sample_x_center_speed_a],
                       ['data', 'sample', 'sample_c'],
                       "speed distribution (virgin)", n_x, n_y)
    plt.savefig(rslt_dir + "/distributions/speed_a.jpg")
    plt.close()

    plot_list_of_speed([data_speed_b, sample_speed_b, sample_x_center_speed_b],
                       ['data', 'sample', 'sample_c'],
                       "speed distribution (mother)", n_x, n_y)
    plt.savefig(rslt_dir + "/distributions/speed_b.jpg")
    plt.close()

    try:
        if 100 < data.shape[0] <= 36000:
            plot_space_dist(data, x_grids, y_grids)
        elif data.shape[0] > 36000:
            plot_space_dist(data[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_data.jpg")
        plt.close()

        if 100 < sample_x.shape[0] <= 36000:
            plot_space_dist(sample_x, x_grids, y_grids)
        elif sample_x.shape[0] > 36000:
            plot_space_dist(sample_x[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x.jpg")
        plt.close()

        if 100 < sample_x_center.shape[0] <= 36000:
            plot_space_dist(sample_x_center, x_grids, y_grids)
        elif sample_x_center.shape[0] > 36000:
            plot_space_dist(sample_x_center[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x_center.jpg")
        plt.close()
    except Exception as e:
        print("plot_space_dist unsuccessful: {}".format(e))
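# A hypothetical invocation of rslt_saving (all paths and values below are
# illustrative assumptions, not taken from the experiments in this repo):
#
# rslt_saving(rslt_dir="rslt/example_run", model=model, Df=4, data=data,
#             masks_a=masks_a, masks_b=masks_b,
#             m_kwargs_a=m_kwargs_a, m_kwargs_b=m_kwargs_b,
#             sample_T=1000, train_model=True, losses=losses, quiver_scale=0.8)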