def most_likely_states(self, data, input=None, transition_mkwargs=None, **memory_kwargs):
    """
    Run Viterbi decoding to find the most likely discrete state sequence.

    :param data: (T, D) observations (converted to tensor on self.device)
    :param input: optional external input; required for InputDrivenTransition
    :param transition_mkwargs: extra kwargs forwarded to the grid transition
    :param memory_kwargs: extra kwargs forwarded to the observation log_prob
    :return: np.ndarray of most likely states, empty array for empty data
    :raises ValueError: if an InputDrivenTransition is used without input
    """
    if len(data) == 0:
        return np.array([])
    if isinstance(self.transition, InputDrivenTransition) and input is None:
        raise ValueError("Please provide input.")
    if input is not None:
        input = check_and_convert_to_tensor(input, device=self.device)
    data = check_and_convert_to_tensor(data, device=self.device)

    T = data.shape[0]
    log_pi0 = get_np(self.init_state_distn.log_probs)

    if isinstance(self.transition, StationaryTransition):
        log_Ps = self.transition.log_stationary_transition_matrix
        log_Ps = get_np(log_Ps)  # (K, K)
        # add a leading time axis so viterbi sees a (1, K, K) matrix
        log_Ps = log_Ps[None, ]
    else:
        assert isinstance(self.transition, GridTransition), type(self.transition)
        transition_mkwargs = transition_mkwargs if transition_mkwargs else {}
        # BUG FIX: `if input` triggered tensor truthiness, which raises a
        # RuntimeError for multi-element tensors; test against None instead.
        input = input[:-1] if input is not None else input
        log_Ps = self.transition.log_transition_matrix(data[:-1], input, **transition_mkwargs)
        log_Ps = get_np(log_Ps)
        assert log_Ps.shape == (T - 1, self.K, self.K), log_Ps.shape

    log_likes = get_np(self.observation.log_prob(data, **memory_kwargs))
    return viterbi(log_pi0, log_Ps, log_likes)
def plot_grid_transition(n_x, n_y, grid_transition): """ Note: this is for single animal case plot the grid transition matrices. return a Figure object """ # TODO: fit for n_x = 1 or n_y = 1 assert n_x > 1 and n_y > 1 fig, axn = plt.subplots(n_y, n_x, sharex=True, sharey=True, figsize=(10, 10)) cbar_ax = fig.add_axes([.93, .3, .03, .4]) # n_x corresponds to the number of columns, and n_y corresponds to the number of rows. grid_idx = 0 for i in range(n_x): for j in range(n_y): # plot_idx = ij_to_plot_idx(i, j, n_x, n_y) # plt.subplot(n_x, n_y, plot_idx) ax = axn[n_y - j - 1][i] if i == 0 and j == 0: sns.heatmap(get_np(grid_transition[grid_idx]), ax=ax, vmin=0, vmax=1, cmap="BuGn", square=True, cbar_ax=cbar_ax) else: sns.heatmap(get_np(grid_transition[grid_idx]), ax=ax, vmin=0, vmax=1, cmap="BuGn", square=True, cbar=False) ax.tick_params(axis='both', which='both', length=0) grid_idx += 1 plt.tight_layout(rect=[0, 0, .9, 1])
def plot_grid_transition(self, axn, cbar_ax, n_x, n_y, grid_transition): """ Note: this is for single animal case plot the grid transition matrices. return a Figure object """ # n_x corresponds to the number of columns, and n_y corresponds to the number of rows. grid_idx = 0 for i in range(n_x): for j in range(n_y): # plot_idx = ij_to_plot_idx(i, j, n_x, n_y) # plt.subplot(n_x, n_y, plot_idx) ax = axn[n_y - j - 1][i] ax.clear() if i == 0 and j == 0: sns.heatmap(get_np(grid_transition[grid_idx]), ax=ax, vmin=0, vmax=1, cmap="BuGn", square=True, cbar_ax=cbar_ax) else: sns.heatmap(get_np(grid_transition[grid_idx]), ax=ax, vmin=0, vmax=1, cmap="BuGn", square=True, cbar=False) ax.tick_params(axis='both', which='both', length=0) grid_idx += 1
def get_speed_for_single_animal(data, x_grids, y_grids, device=torch.device('cpu')):
    """
    Group per-step speeds of a single animal's (T, 2) trajectory by grid cell.

    :param data: (T, 2) trajectory, np.ndarray or torch.Tensor
    :param x_grids: x grid boundaries
    :param y_grids: y grid boundaries
    :param device: torch device used when converting numpy input
    :return: list of length G; entry g holds the speeds of steps starting in grid g
    :raises ValueError: for unsupported data types
    """
    _, D = data.shape
    assert D == 2

    if isinstance(data, np.ndarray):
        # take displacements on the numpy side, then move to torch for masking
        step = np.diff(data, axis=0)  # (T-1, 2)
        data = torch.tensor(data, dtype=torch.float64, device=device)
        masks_a = get_masks_for_single_animal(data[:-1], x_grids, y_grids)
    elif isinstance(data, torch.Tensor):
        masks_a = get_masks_for_single_animal(data[:-1], x_grids, y_grids)
        step = np.diff(get_np(data), axis=0)  # (T-1, 2)
    else:
        raise ValueError("Data must be either np.ndarray or torch.Tensor")

    # Euclidean step length, shape (T-1,)
    all_speeds = np.sqrt(step[:, 0] ** 2 + step[:, 1] ** 2)

    n_grids = (len(x_grids) - 1) * (len(y_grids) - 1)
    return [all_speeds[get_np(masks_a[g]) == 1] for g in range(n_grids)]
def plot_quiver(self, axs, K, scale=1, alpha=1, x_grids=None, y_grids=None, grid_alpha=1):
    """
    Draw the learned flow field as a quiver plot for each of the K states.

    :param axs: a single axis when K == 1, otherwise an indexable of K axes
    :param K: number of discrete states to plot
    :param scale: quiver arrow scale
    :param alpha: quiver arrow transparency
    :param x_grids, y_grids, grid_alpha: currently unused (grid overlay is commented out)
    """
    #TODO: for single animal now
    # quiver
    # NOTE(review): arena extent is hard-coded to x in [20, 310], y in [0, 380] — confirm
    XX, YY = np.meshgrid(np.linspace(20, 310, 30), np.linspace(0, 380, 30))
    XYs = np.column_stack((np.ravel(XX), np.ravel(YY)))  # shape (900,2) grid values
    if isinstance(self.observation, GPObservationSingle):
        XY_next, _ = self.observation.get_mu_and_cov_for_single_animal(XYs, mu_only=True)
    else:
        XY_next = self.observation.transformation.transform(torch.tensor(XYs, dtype=torch.float64, device=self.device))
    # displacement field per state; XY_next presumably has shape (900, K, 2) — confirm
    dXYs = get_np(XY_next) - XYs[:, None]
    if K == 1:
        # single-state case: axs is one axis, not an array
        axs.clear()
        axs.quiver(XYs[:, 0], XYs[:, 1], dXYs[:, 0, 0], dXYs[:, 0, 1],
                   angles='xy', scale_units='xy', scale=scale, alpha=alpha)
        # add_grid(x_grids, y_grids, grid_alpha=grid_alpha)
        if isinstance(self.observation, GPObservationSingle):
            axs.set_title("K={}, rs={}".format(0, get_np(self.observation.rs[0])))
        else:
            axs.set_title("K={}".format(0))
    else:
        for k in range(K):
            axs[k].clear()
            axs[k].quiver(XYs[:, 0], XYs[:, 1], dXYs[:, k, 0], dXYs[:, k, 1],
                          angles='xy', scale_units='xy', scale=scale, alpha=alpha)
            #add_grid(x_grids, y_grids, grid_alpha=grid_alpha)
            if isinstance(self.observation, GPObservationSingle):
                axs[k].set_title('K={}, rs={}'.format(k, get_np(self.observation.rs[k])))
            else:
                axs[k].set_title('K={}'.format(k))
def add_grid_to_ax(ax, x_grids, y_grids):
    """
    Overlay dashed grid lines and the four outer corner points onto `ax`.

    :param ax: matplotlib axis to draw on
    :param x_grids: x grid boundaries (list/array or torch.Tensor)
    :param y_grids: y grid boundaries (list/array or torch.Tensor)
    """
    if isinstance(x_grids, torch.Tensor):
        x_grids = get_np(x_grids)
    if isinstance(y_grids, torch.Tensor):
        y_grids = get_np(y_grids)

    x_lo, x_hi = x_grids[0], x_grids[-1]
    y_lo, y_hi = y_grids[0], y_grids[-1]

    # mark the four corners of the grid world
    ax.scatter([x_lo, x_lo, x_hi, x_hi], [y_lo, y_hi, y_lo, y_hi])

    # horizontal dashed lines at every y boundary
    for j in range(len(y_grids)):
        ax.plot([x_lo, x_hi], [y_grids[j], y_grids[j]], '--', color='grey')
    # vertical dashed lines at every x boundary
    for i in range(len(x_grids)):
        ax.plot([x_grids[i], x_grids[i]], [y_lo, y_hi], '--', color='grey')
def add_grid(x_grids, y_grids, grid_alpha=1.0):
    """
    Overlay dashed grid lines and the four outer corner points onto the
    current matplotlib figure (pyplot state-machine variant of add_grid_to_ax).

    :param x_grids: x grid boundaries (list/array or torch.Tensor), or None to skip
    :param y_grids: y grid boundaries (list/array or torch.Tensor), or None to skip
    :param grid_alpha: transparency for all drawn elements
    """
    if x_grids is None or y_grids is None:
        return
    if isinstance(x_grids, torch.Tensor):
        x_grids = get_np(x_grids)
    if isinstance(y_grids, torch.Tensor):
        y_grids = get_np(y_grids)
    # BUG FIX: corners must use the extreme boundaries y_grids[-1] (as in
    # add_grid_to_ax); the original used y_grids[1], which is only correct
    # when there are exactly two y grid lines.
    plt.scatter([x_grids[0], x_grids[0], x_grids[-1], x_grids[-1]],
                [y_grids[0], y_grids[-1], y_grids[0], y_grids[-1]],
                alpha=grid_alpha)
    # horizontal dashed lines at every y boundary
    for j in range(len(y_grids)):
        plt.plot([x_grids[0], x_grids[-1]], [y_grids[j], y_grids[j]],
                 '--', color='grey', alpha=grid_alpha)
    # vertical dashed lines at every x boundary
    for i in range(len(x_grids)):
        plt.plot([x_grids[i], x_grids[i]], [y_grids[0], y_grids[-1]],
                 '--', color='grey', alpha=grid_alpha)
def sample_z(self, T):
    """
    Sample a length-T discrete state sequence from the Markov chain alone
    (no emissions). Only the time-invariant stationary transition is supported.

    :param T: number of time steps to sample
    :return: torch int tensor of shape (T,) on self.device
    """
    assert isinstance(self.transition, StationaryTransition), \
        "Sampling the makov chain only supports for stationary transition"
    states = torch.empty(T, dtype=torch.int, device=self.device)

    # draw the initial state from the initial distribution
    initial_probs = get_np(self.init_state_distn.probs)
    states[0] = npr.choice(self.K, p=initial_probs)

    # then roll the chain forward one step at a time
    trans_matrix = get_np(self.transition.stationary_transition_matrix)  # (K, K)
    for step in range(1, T):
        states[step] = npr.choice(self.K, p=trans_matrix[states[step - 1]])
    return states
def plot_dynamics(weighted_corner_vecs, animal, x_grids, y_grids, K, scale=0.1, percentage=None, title=None, grid_alpha=1):
    """
    This is for the illustration of the dynamics of the discrete grid model.
    Probably want to make the case of K>8

    :param weighted_corner_vecs: per-grid, per-state, per-feature weighted corner
        vectors; indexed as [grid, k, df, xy], so presumably (G, K, Df, 2) — confirm
    :param animal: label appended to each panel title
    :param x_grids, y_grids: grid boundaries (arrays or torch.Tensor)
    :param K: number of discrete states (panels); K > 8 is silently not plotted
    :param scale: quiver scale
    :param percentage: optional per-grid occupancy passed to add_percentage
    :param title: optional figure suptitle
    :param grid_alpha: transparency of the grid overlay
    """
    if isinstance(x_grids, torch.Tensor):
        x_grids = get_np(x_grids)
    if isinstance(y_grids, torch.Tensor):
        y_grids = get_np(y_grids)
    # resultant vector per grid/state: sum over the Df feature components
    result_corner_vecs = np.sum(weighted_corner_vecs, axis=2)
    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1
    # arrows are anchored at the center of each grid cell
    grid_centers = np.array([[1 / 2 * (x_grids[i] + x_grids[i + 1]), 1 / 2 * (y_grids[j] + y_grids[j + 1])]
                             for i in range(n_x) for j in range(n_y)])
    Df = 4

    def plot_dynamics_k(k):
        # one panel for state k: grid overlay, occupancy labels, the Df component
        # arrows (default color), and their resultant in red
        add_grid(x_grids, y_grids, grid_alpha=grid_alpha)
        add_percentage(k, percentage=percentage, grid_centers=grid_centers)
        for df in range(Df):
            plt.quiver(grid_centers[:, 0], grid_centers[:, 1],
                       weighted_corner_vecs[:, k, df, 0], weighted_corner_vecs[:, k, df, 1],
                       units='xy', scale=scale, width=2, alpha=0.5)
        plt.quiver(grid_centers[:, 0], grid_centers[:, 1],
                   result_corner_vecs[:, k, 0], result_corner_vecs[:, k, 1],
                   units='xy', scale=scale, width=2, color='red', alpha=0.5)
        plt.title("K={}, ".format(k) + animal, fontsize=20)

    if K <= 4:
        # single row of panels
        plt.figure(figsize=(20, 5))
        if title is not None:
            plt.suptitle(title)
        for k in range(K):
            plt.subplot(1, K, k+1)
            plot_dynamics_k(k)
    elif 4 < K <= 8:
        # two rows of panels
        plt.figure(figsize=(20, 10))
        if title is not None:
            plt.suptitle(title)
        for k in range(K):
            plt.subplot(2, int(K/2)+1, k+1)
            plot_dynamics_k(k)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
def get_all_angles(data, x_grids, y_grids, device=torch.device('cpu')):
    """
    Compute per-grid movement angles from the consecutive displacements of `data`.

    :param data: (T, D) trajectory
    :param x_grids, y_grids: grid boundaries
    :param device: torch device for internal conversions
    :return: whatever get_all_angles_from_quiver returns for the step vectors
    """
    displacement = data[1:] - data[:-1]
    if isinstance(displacement, torch.Tensor):
        displacement = get_np(displacement)
    return get_all_angles_from_quiver(data[:-1], displacement, x_grids, y_grids, device=device)
def get_gpt_idx_and_grid_for_single(self, point):
    """
    Locate the grid cell containing `point` and the indices of the four
    grid points (Q11, Q12, Q21, Q22) at its corners.

    :param point: (2,)
    :return: (gpt_idx, grid_idx); gpt_idx a list of length 4
    :raises ValueError: if the point lies outside the grid world
    """
    assert point.shape == (2, )
    l_y = len(self.y_grids)
    grid_idx = 0
    for i in range(len(self.x_grids) - 1):
        for j in range(len(self.y_grids) - 1):
            inside_x = self.x_grids[i] <= point[0] <= self.x_grids[i + 1]
            inside_y = self.y_grids[j] <= point[1] <= self.y_grids[j + 1]
            if inside_x & inside_y:
                corners = [i * l_y + j,            # Q11
                           i * l_y + j + 1,        # Q12
                           (i + 1) * l_y + j,      # Q21
                           (i + 1) * l_y + j + 1]  # Q22
                return corners, grid_idx
            grid_idx += 1
    raise ValueError("value {} out of the grid world.".format(get_np(point)))
def get_gridpoints_idx_for_single_point(self, point):
    """
    Q12 -- Q22
    |
    Q11 -- Q21
    :param point: (2,)
    :return: idx (4, ) idx for Q11, Q12, Q21, Q22)
    :raises ValueError: if the point lies outside the grid world
    """
    assert point.shape == (2, )
    for i in range(self.n_x):
        for j in range(self.n_y):
            within_x = self.x_grids[i] <= point[0] <= self.x_grids[i + 1]
            within_y = self.y_grids[j] <= point[1] <= self.y_grids[j + 1]
            if within_x & within_y:
                # Q11 sits at column i, row j of the gridpoint lattice;
                # the other three corners are offsets of +1 row / +ly (one column)
                base = i * self.ly + j
                return [base, base + 1, base + self.ly, base + self.ly + 1]
    raise ValueError("value {} out of the grid world.".format(get_np(point)))
def get_gridpoints_idx_for_single(self, point):
    """
    One-hot membership matrix of the four corner grid points surrounding `point`.

    :param point: (2,)
    :return: idx (n_gps, 4); idx[p, c] == 1 iff grid point p is corner c
        (columns ordered Q11, Q12, Q21, Q22)
    :raises ValueError: if the point lies outside the grid world
    """
    assert point.shape == (2, )
    onehot = torch.zeros((self.GP, 4), dtype=torch.float64, device=self.device)
    l_y = len(self.y_grids)
    for i in range(len(self.x_grids) - 1):
        for j in range(len(self.y_grids) - 1):
            in_x = self.x_grids[i] <= point[0] <= self.x_grids[i + 1]
            in_y = self.y_grids[j] <= point[1] <= self.y_grids[j + 1]
            if in_x & in_y:
                base = i * l_y + j
                onehot[base, 0] = 1            # Q11
                onehot[base + 1, 1] = 1        # Q12
                onehot[base + l_y, 2] = 1      # Q21
                onehot[base + l_y + 1, 3] = 1  # Q22
                return onehot
    raise ValueError("value {} out of the grid world.".format(get_np(point)))
def plot_transition(transition):
    """
    plot the heatmap for one transition matrix
    :param transition: (K, K)
    :return:
    """
    # probabilities, so pin the color range to [0, 1]
    sns.heatmap(get_np(transition), vmin=0, vmax=1, cmap="BuGn", square=True)
def plot_space_dist(data, x_grids, y_grids, grid_alpha=1):
    """
    Plot the spatial KDE of the trajectory, one panel per animal
    (D == 4: virgin + mother; D == 2: a single animal), with the grid overlaid.

    :param data: (T, 2) or (T, 4) trajectory, np.ndarray or torch.Tensor
    :param x_grids, y_grids: grid boundaries for the overlay
    :param grid_alpha: transparency of the grid overlay
    """
    T, D = data.shape
    assert D == 2 or D == 4
    if isinstance(data, torch.Tensor):
        data = get_np(data)
    # heuristic contour count scaled with the trajectory length
    n_levels = int(T / 36)
    if D == 4:
        # two-animal case: columns 0-1 are the virgin, columns 2-3 the mother
        plt.figure(figsize=(15, 7))
        plt.subplot(1, 2, 1)
        sns.kdeplot(data[:, 0], data[:, 1], n_levels=n_levels)
        add_grid(x_grids, y_grids, grid_alpha=grid_alpha)
        plt.title("virgin")
        plt.subplot(1, 2, 2)
        sns.kdeplot(data[:, 2], data[:, 3], n_levels=n_levels)
        add_grid(x_grids, y_grids, grid_alpha=grid_alpha)
        plt.title("mother")
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    else:
        plt.figure(figsize=(8,7))
        sns.kdeplot(data[:,0], data[:,1], n_levels=n_levels)
        add_grid(x_grids, y_grids)
def get_z_percentage_by_grid(masks_a, z, K, G):
    """
    For each grid cell, the fraction of time steps spent in each hidden state.

    :param masks_a: per-grid membership masks, indexable by g in range(G)
    :param z: (T,) hidden state sequence (only z[:-1] is used, to align with masks)
    :param K: number of hidden states
    :param G: number of grid cells
    :return: (G, K) array of row-normalized state occupancies
    """
    # shift states to 1..K so that masked-out entries (value 0) can never
    # collide with a real state id
    tagged = np.array([(z[:-1] + 1) * get_np(masks_a[g]) for g in range(G)])
    # (G, K): for each grid g, count the time steps spent in state k
    counts = np.array([[sum(tagged[g] == k) for k in range(1, K + 1)]
                       for g in range(G)])
    # row-normalize; the epsilon guards grids that were never visited
    return counts / (counts.sum(axis=1)[:, None] + 1e-6)
def __init__(self, K, D, x_grids, y_grids, Df, feature_vec_func, tran=None, acc_factor=2, lags=1,
             use_log_prior=False, no_boundary_prior=False, add_log_diagonal_prior=False,
             log_prior_sigma_sq=-np.log(1e3), device=torch.device('cpu'), version=1):
    """
    Linear grid transformation over a 2-D grid world.

    :param K: number of discrete states
    :param D: observation dimension (2 per animal; self.d = D / 2)
    :param x_grids: x grid boundaries [x_0, x_1, ..., x_m]
    :param y_grids: y grid boundaries [y_0, y_1, ..., y_n]
    :param Df: number of feature vectors per point
    :param feature_vec_func: callable producing the feature vectors
    :param tran: optional existing LinearGridTransformation to copy parameters from
        (its values are detached and re-wrapped as fresh leaf tensors)
    :param acc_factor: acceleration factor (ignored when `tran` is given)
    :param lags: must be 1; only one-step dependence is supported
    :param use_log_prior, no_boundary_prior, add_log_diagonal_prior: prior flags
    :param log_prior_sigma_sq: log prior variance (default -log(1e3))
    :param device: torch device for all tensors
    :param version: implementation version selector
    """
    assert lags == 1, "lags should be 1 for lineargrid with_noise."
    super(LinearGridTransformation, self).__init__(K, D)
    self.version = version
    self.d = int(self.D / 2)
    self.device = device
    self.x_grids = check_and_convert_to_tensor(x_grids, dtype=torch.float64, device=self.device)  # [x_0, x_1, ..., x_m]
    self.y_grids = check_and_convert_to_tensor(y_grids, dtype=torch.float64, device=self.device)  # a list [y_0, y_1, ..., y_n]
    # number of grid cells along each axis
    self.n_x = len(x_grids) - 1
    self.n_y = len(y_grids) - 1
    # shape: (d, n_gps) — all lattice points, ordered x-major, transposed to columns
    self.gridpoints = torch.tensor([(x_grid, y_grid) for x_grid in self.x_grids for y_grid in self.y_grids], device=device)
    self.gridpoints = torch.transpose(self.gridpoints, 0, 1)
    # number of basis grid points
    self.GP = self.gridpoints.shape[1]
    self.Df = Df
    self.feature_vec_func = feature_vec_func
    if tran is not None:
        # warm-start: copy configuration and weights from an existing transformation
        assert isinstance(tran, LinearGridTransformation)
        self.use_log_prior = tran.use_log_prior
        self.add_log_diagonal_prior = tran.add_log_diagonal_prior
        self.no_boundary_prior = tran.no_boundary_prior
        self.log_prior_sigma_sq = torch.tensor(get_np(tran.log_prior_sigma_sq), dtype=torch.float64, device=self.device)
        self.acc_factor = tran.acc_factor
        # re-wrap as a fresh leaf tensor so this instance trains independently
        self.Ws = torch.tensor(get_np(tran.Ws), dtype=torch.float64, requires_grad=True, device=self.device)
    else:
        self.use_log_prior = use_log_prior
        self.add_log_diagonal_prior = add_log_diagonal_prior
        self.no_boundary_prior = no_boundary_prior
        self.log_prior_sigma_sq = torch.tensor(log_prior_sigma_sq, dtype=torch.float64, device=device)
        self.acc_factor = acc_factor
        # trainable weights, one (2, GP, Df) block per state
        self.Ws = torch.rand(self.K, 2, self.GP, self.Df, dtype=torch.float64, requires_grad=True, device=self.device)
def sample(self, sample_shape=torch.Size()):
    """
    Draw a sample from the truncated normal distribution.

    :param sample_shape: currently consider one
    :return: (D, ) torch.float64 tensor (or `sample_shape` shaped for non-empty shapes)
    """
    # first, transform the bounds to the standard-normal scale expected by
    # scipy's truncnorm (a/b are given in standard deviations from loc)
    scale = np.exp(get_np(self.log_sigmas))  # (D, )
    loc = get_np(self.mus)  # (D, )
    bb = (get_np(self.bounds[..., 1]) - loc) / scale  # (D, ) standardized upper bound
    aa = (get_np(self.bounds[..., 0]) - loc) / scale  # (D, ) standardized lower bound
    if sample_shape == ():
        D = bb.shape[0]
        samples = truncnorm.rvs(a=aa, b=bb, loc=loc, scale=scale, size=(D))
        # some adhoc way to fix the infinity in sample.
        # BUG FIX: clamp -inf to the LOWER bound and +inf to the UPPER bound;
        # the original assignment was swapped (-inf got bounds[..., 1] and
        # +inf got bounds[..., 0]), sending outliers to the wrong edge.
        samples[samples == -np.inf] = get_np(
            self.bounds[..., 0])[samples == -np.inf]
        samples[samples == np.inf] = get_np(
            self.bounds[..., 1])[samples == np.inf]
    else:
        # NOTE(review): the infinity clamp is not applied on this path — confirm
        samples = truncnorm.rvs(a=aa, b=bb, loc=loc, scale=scale, size=sample_shape)
    samples = torch.tensor(samples, dtype=torch.float64)
    return samples
def sample_condition_on_zs(self, zs, x0=None, transformation=False, return_np=True, **kwargs):
    """
    Given a z sequence, generate samples condition on this sequence.
    :param zs: (T, )
    :param x0: shape (D,)
    :param transformation: passed to sample_x as with_noise (deterministic transform)
    :param return_np: return np.ndarray or torch.tensor
    :return: generated samples (T, D)
    """
    zs = check_and_convert_to_tensor(zs, dtype=torch.int, device=self.device)
    T = zs.shape[0]
    assert T > 0
    dtype = torch.float64
    xs = torch.zeros((T, self.D), dtype=dtype)
    if T == 1:
        if x0 is not None:
            print("Nothing to sample")
            return
        else:
            return self.observation.sample_x(zs[0], with_noise=transformation)
    if x0 is None:
        x0 = self.observation.sample_x(zs[0], with_noise=transformation, return_np=False)
    else:
        x0 = check_and_convert_to_tensor(x0, dtype=dtype, device=self.device)
        assert x0.shape == (self.D, )
    xs[0] = x0
    for t in np.arange(1, T):
        # BUG FIX: the keyword was misspelled "xihst", so sample_x's **kwargs
        # silently swallowed it and the autoregressive history never reached
        # xhist (which stayed None, making every step behave like t == 0).
        x_t = self.observation.sample_x(zs[t], xhist=xs[:t], with_noise=transformation,
                                        return_np=False, **kwargs)
        xs[t] = x_t
    if return_np:
        return get_np(xs)
    return xs
def plot_transition(self, ax, cbar_ax, transition):
    """
    plot the heatmap for one transition matrix
    :param ax: axis to draw into (cleared first, so the figure can be redrawn in place)
    :param cbar_ax: dedicated axis for the colorbar
    :param transition: (K, K)
    :return:
    """
    ax.clear()
    # probabilities, so pin the color range to [0, 1]
    sns.heatmap(get_np(transition), ax=ax, vmin=0, vmax=1, cmap="BuGn", square=True, cbar_ax=cbar_ax)
def get_all_angles_from_quiver(XY, dXY, x_grids, y_grids, device=torch.device('cpu')):
    """
    Group quiver directions by grid cell and convert them to angles.

    :param XY: (T, D) positions; D == 2 (one animal) or D == 4 (two animals)
    :param dXY: displacements, same leading shape as XY
    :param x_grids, y_grids: grid boundaries
    :param device: torch device used when converting numpy positions
    :return: list of per-grid angles (D == 2), or a pair of such lists (D == 4)
    :raises ValueError: for any other D
    """
    if isinstance(XY, np.ndarray):
        XY = torch.tensor(XY, dtype=torch.float64, device=device)
    _, D = XY.shape
    n_grids = (len(x_grids) - 1) * (len(y_grids) - 1)
    if D == 4:
        # two animals: columns 0-1 belong to animal a, columns 2-3 to animal b
        masks_a, masks_b = get_masks_for_two_animals(XY, x_grids, y_grids)
        masks_a = get_np(masks_a)
        masks_b = get_np(masks_b)
        angles_a = [get_angles_single_from_quiver(dXY[masks_a[g] == 1][:, 0:2])
                    for g in range(n_grids)]
        angles_b = [get_angles_single_from_quiver(dXY[masks_b[g] == 1][:, 2:4])
                    for g in range(n_grids)]
        return angles_a, angles_b
    elif D == 2:
        masks_a = get_np(get_masks_for_single_animal(XY, x_grids, y_grids))
        return [get_angles_single_from_quiver(dXY[masks_a[g] == 1][:, 0:2])
                for g in range(n_grids)]
    else:
        raise ValueError("Invalid data shape")
def plot_mouse(data, alpha=.8, title=None, xlim=None, ylim=None, mouse='both'):
    """
    Plot 2-D trajectories on the current figure.

    :param data: (T, 4) for the virgin+mother pair, or (T, 2) for one mouse
    :param alpha: line transparency
    :param title: optional plot title
    :param xlim, ylim: optional axis limits
    :param mouse: legend label used in the single-mouse case
    """
    if isinstance(data, torch.Tensor):
        data = get_np(data)
    if title is not None:
        plt.title(title)
    _, D = data.shape
    assert D == 4 or D == 2
    if D == 4:
        # two-animal layout: columns 0-1 virgin, columns 2-3 mother
        plt.plot(data[:, 0], data[:, 1], label='virgin', alpha=alpha)
        plt.plot(data[:, 2], data[:, 3], label='mother', alpha=alpha)
    else:
        plt.plot(data[:, 0], data[:, 1], alpha=alpha, label=mouse)
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
def plot_2d_time_plot_condition_on_all_zs(data, z, K, title, time_start=None, time_end=None, size=0.5):
    """
    One row of panels per hidden state: coordinates over time, highlighted
    where the inferred state equals k.

    :param data: (T, D) trajectory (tensor or array; converted via get_np)
    :param z: (T,) hidden state sequence
    :param K: number of hidden states (and subplot rows)
    :param title: optional figure suptitle
    :param time_start, time_end: window to plot; default full range [0, T)
    :param size: marker size forwarded to the per-state plotter
    """
    data = get_np(data)
    T, _ = data.shape
    # falsy (None or 0) start defaults to 0; falsy end defaults to T
    start = time_start if time_start else 0
    end = time_end if time_end else T
    plt.figure(figsize=(30, 2 * K))
    if title is not None:
        plt.suptitle(title)
    for state in range(K):
        plt.subplot(K, 1, state + 1)
        plot_2d_time_plot_condition_on_z(data, z, state, start, end, size)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
def sample_x(self, z, xhist=None, with_noise=False, return_np=True, **memory_kwargs):
    """
    :param z: an integer
    :param xhist: (T_pre, D)
    :param with_noise: return transformed value as sample value, instead of sampling
    :param return_np: boolean, whether return np.ndarray or torch.tensor
    :return: one sample (D, )
    """
    # NOTE(review): with_noise=True returns the deterministic mean (no sampling),
    # so the flag name reads inverted relative to its effect — confirm intent.
    if xhist is None or xhist.shape[0] == 0:
        # no history: fall back to the per-state initial distribution
        mu = self.mus_init[z]  # (D,)
        log_sigma = self.log_sigmas_init[z]
    else:
        # sample from the autoregressive distribution
        T_pre = xhist.shape[0]
        if T_pre < self.lags:
            # shorter history than the lag order: condition on what we have
            mu = self.transformation.transform_condition_on_z(z, xhist, **memory_kwargs)  # (D, )
        else:
            mu = self.transformation.transform_condition_on_z(z, xhist[-self.lags:], **memory_kwargs)  # (D, )
        assert mu.shape == (self.D, )
        log_sigma = self.log_sigmas[z]
    if with_noise:
        samples = mu
    else:
        dist = TruncatedNormal(mus=mu, log_sigmas=log_sigma, bounds=self.bounds)
        samples = dist.sample()
        # clip each dimension back into its bounds after sampling
        for d in range(self.D):
            samples[d] = clip(samples[d], self.bounds[d])
    if return_np:
        return get_np(samples)
    return samples
def plot_data_condition_on_all_zs(data, z, K, size=2, alpha=0.3):
    """
    Scatter the data separately for each hidden state: K panels laid out
    in a grid five columns wide.

    :param data: (T, D) trajectory (tensor or array; converted via get_np)
    :param z: (T,) hidden state sequence
    :param K: number of hidden states (panels)
    :param size: scatter marker size
    :param alpha: scatter transparency
    """
    data = get_np(data)
    n_col = 5
    # ceil(K / n_col) rows
    n_row = int(K / n_col) + (1 if K % n_col > 0 else 0)
    plt.figure(figsize=(20, 4 * n_row))
    plt.suptitle("spatial occupation under different hidden states")
    for state in range(K):
        plt.subplot(n_row, n_col, state + 1)
        plot_data_condition_on_zk(data, z, state, size=size, alpha=alpha)
        plt.title('K={} '.format(state))
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
def get_grid_idx_for_single(self, point):
    """
    Return the index of the grid cell containing `point`.

    :param point: (2,)
    :return: grid idx: a scalar
    :raises ValueError: if the point lies outside the grid world
    """
    assert point.shape == (2,), point.shape
    cell = 0
    for i in range(len(self.x_grids) - 1):
        for j in range(len(self.y_grids) - 1):
            in_x = self.x_grids[i] <= point[0] <= self.x_grids[i + 1]
            in_y = self.y_grids[j] <= point[1] <= self.y_grids[j + 1]
            if in_x & in_y:
                return cell
            cell += 1
    raise ValueError("value {} out of the grid world.".format(get_np(point)))
def k_step_prediction_for_lstm_based_model(model, model_z, data, k=0, feature_vecs=None):
    """
    k-step-ahead prediction for an LSTM-based model.

    :param model: the fitted model (must expose observation.sample_x and sample)
    :param model_z: (T,) inferred state sequence to condition on
    :param data: (T, D) observations
    :param k: prediction horizon; k == 0 means one-step filtering with real history
    :param feature_vecs: pair (feature_vecs_a, feature_vecs_b) of precomputed
        per-step features; required for the k == 0 fast path
    :return: np.ndarray of predictions, shape (T - k, D) for k > 0
    :raises ValueError: if T <= k
    """
    data = check_and_convert_to_tensor(data)
    T, D = data.shape
    # LSTM hidden/cell state threaded through successive sample calls
    lstm_states = {}
    x_predict_arr = []
    if k == 0:
        if feature_vecs is None:
            print("Did not provide memory information")
            # fall back to the generic (non-LSTM) predictor
            return k_step_prediction(model, model_z, data)
        else:
            feature_vecs_a, feature_vecs_b = feature_vecs
            # t = 0 has no history (data[:0] is empty)
            x_predict = model.observation.sample_x(model_z[0], data[:0], return_np=True)
            x_predict_arr.append(x_predict)
            for t in range(1, data.shape[0]):
                # features computed from the previous real observation
                feature_vec_t = (feature_vecs_a[t - 1:t], feature_vecs_b[t - 1:t])
                x_predict = model.observation.sample_x(model_z[t], data[:t], return_np=True, with_noise=True,
                                                       feature_vec=feature_vec_t, lstm_states=lstm_states)
                x_predict_arr.append(x_predict)
    else:
        assert k > 0
        # neglects t = 0 since there is no history
        if T <= k:
            raise ValueError("Please input k such that k < {}.".format(T))
        for t in range(1, T - k + 1):
            # sample k steps forward
            # first step use real value
            z, x = model.sample(1, prefix=(model_z[t - 1:t], data[t - 1:t]), return_np=False, with_noise=True,
                                lstm_states=lstm_states)
            # last k-1 steps use sampled value
            if k >= 1:
                # copy the LSTM state so the rollout does not disturb the
                # real-data state used at the next value of t
                sampled_lstm_states = dict(h_t=lstm_states["h_t"], c_t=lstm_states["c_t"])
                for i in range(k - 1):
                    z, x = model.sample(1, prefix=(z, x), return_np=False, with_noise=True,
                                        lstm_states=sampled_lstm_states)
            assert x.shape == (1, D)
            x_predict_arr.append(get_np(x[0]))
    x_predict_arr = np.array(x_predict_arr)
    assert x_predict_arr.shape == (T - k, D)
    return x_predict_arr
def main(job_name, cuda_num, downsample_n, filter_traj, gp_version, load_model, load_model_dir, load_opt_dir,
         transition, sticky_alpha, sticky_kappa, acc_factor, k, x_grids, y_grids, n_x, n_y, rs_factor, rs,
         train_rs, train_model, pbar_update_interval, video_clips, held_out_proportion, torch_seed, np_seed,
         list_of_num_iters, ckpts_not_to_save, list_of_lr, list_of_k_steps, sample_t, quiver_scale):
    """
    End-to-end experiment driver for the GP-grid model: parse CLI-style string
    arguments, load/downsample the trajectory data, build or load the model,
    optionally train it in stages, and save parameters/checkpoints/results.

    Many arguments arrive as comma-separated strings (rs_factor, video_clips,
    list_of_num_iters, list_of_lr, list_of_k_steps, ckpts_not_to_save,
    x_grids, y_grids) and are parsed below.
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")
    cuda_num = int(cuda_num)
    device = torch.device("cuda:{}".format(cuda_num) if torch.cuda.is_available() else "cpu")
    print("Using device {} \n\n".format(device))

    K = k
    sample_T = sample_t

    # "0,0" acts as a sentinel meaning "no rs_factor"
    rs_factor = np.array([float(x) for x in rs_factor.split(",")])
    if rs_factor[0] == 0 and rs_factor[1] == 0:
        rs_factor = None
    rs = float(rs)

    # clip window given in units of 36000 frames (see the slicing below)
    video_clip_start, video_clip_end = [float(x) for x in video_clips.split(",")]

    # per-stage training schedule: one (num_iters, lr) pair per stage
    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    list_of_k_steps = [int(x) for x in list_of_k_steps.split(",")]
    assert len(list_of_num_iters) == len(list_of_lr), \
        "Length of list_of_num_iters must match length of list_of_lr."
    for lr in list_of_lr:
        if lr > 1:
            raise ValueError("Learning rate should not be larger than 1!")
    ckpts_not_to_save = [int(x) for x in ckpts_not_to_save.split(',')] if ckpts_not_to_save else []

    # locate the repository root so data/result paths are repo-relative
    repo = git.Repo('.', search_parent_directories=True)
    repo_dir = repo.working_tree_dir  # SocialBehavior

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    ########################## data ########################
    data_dir = repo_dir + '/SocialBehaviorptc/data/trajs_all'
    trajs = joblib.load(data_dir)

    # 36000 frames per clip unit — presumably 10 minutes at 60 Hz; confirm
    traj = trajs[int(36000 * video_clip_start):int(36000 * video_clip_end)]
    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)
    data = torch.tensor(traj, dtype=torch.float64, device=device)

    assert 0 <= held_out_proportion <= 0.4, \
        "held_out-portion should be between 0 and 0.4 (inclusive), but is {}".format(held_out_proportion)
    T = data.shape[0]
    # NOTE(review): `breakpoint` shadows the builtin of the same name
    breakpoint = int(T * (1 - held_out_proportion))
    training_data = data[:breakpoint]
    valid_data = data[breakpoint:]

    ######################### model ####################
    # model
    D = 4
    M = 0
    Df = 4

    if load_model:
        # resume from a serialized model; grid/acc parameters come from it
        print("Loading the model from ", load_model_dir)
        model = joblib.load(load_model_dir)
        tran = model.observation.transformation
        K = model.K
        n_x = len(tran.x_grids) - 1
        n_y = len(tran.y_grids) - 1
        acc_factor = tran.acc_factor
    else:
        print("Creating the model...")
        bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX],
                           [ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])

        # grids: either parse the provided boundaries, or split the arena evenly
        if x_grids is None:
            x_grid_gap = (ARENA_XMAX - ARENA_XMIN) / n_x
            x_grids = np.array([ARENA_XMIN + i * x_grid_gap for i in range(n_x + 1)])
        else:
            x_grids = np.array([float(x) for x in x_grids.split(",")])
            n_x = len(x_grids) - 1

        if y_grids is None:
            y_grid_gap = (ARENA_YMAX - ARENA_YMIN) / n_y
            y_grids = np.array([ARENA_YMIN + i * y_grid_gap for i in range(n_y + 1)])
        else:
            y_grids = np.array([float(x) for x in y_grids.split(",")])
            n_y = len(y_grids) - 1

        if acc_factor is None:
            acc_factor = downsample_n * 10

        tran = GPGridTransformation(K=K, D=D, x_grids=x_grids, y_grids=y_grids, Df=Df,
                                    feature_vec_func=f_corner_vec_func, acc_factor=acc_factor,
                                    rs_factor=rs_factor, rs=None, train_rs=train_rs, device=device,
                                    version=gp_version)
        obs = ARTruncatedNormalObservation(K=K, D=D, M=M, lags=1, bounds=bounds,
                                           transformation=tran, device=device)

        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        else:
            transition_kwargs = None
        model = HMM(K=K, D=D, M=M, transition=transition, observation=obs,
                    transition_kwargs=transition_kwargs, device=device)
        # initialize the per-state means at the first training observation
        model.observation.mus_init = training_data[0] * torch.ones(K, D, dtype=torch.float64, device=device)

    # save experiment params
    exp_params = {"job_name": job_name,
                  'downsample_n': downsample_n,
                  "filter_traj": filter_traj,
                  "load_model": load_model,
                  "gp_version": gp_version,
                  "load_model_dir": load_model_dir,
                  "load_opt_dir": load_opt_dir,
                  "transition": transition,
                  "sticky_alpha": sticky_alpha,
                  "sticky_kappa": sticky_kappa,
                  "acc_factor": acc_factor,
                  "K": K,
                  "x_grids": x_grids,
                  "y_grids": y_grids,
                  "n_x": n_x,
                  "n_y": n_y,
                  "rs_factor": get_np(tran.rs_factor),
                  "rs": rs,
                  "train_rs": train_rs,
                  "train_model": train_model,
                  "pbar_update_interval": pbar_update_interval,
                  "video_clip_start": video_clip_start,
                  "video_clip_end": video_clip_end,
                  "held_out_proportion": held_out_proportion,
                  "torch_seed": torch_seed,
                  "np_seed": np_seed,
                  "list_of_num_iters": list_of_num_iters,
                  "list_of_lr": list_of_lr,
                  "list_of_k_steps": list_of_k_steps,
                  "sample_T": sample_T,
                  "quiver_scale": quiver_scale}
    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/gpgrid/" + job_name)
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving to rlst_dir: ", rslt_dir)
    with open(rslt_dir + "/exp_params.json", "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    # compute memory
    print("Computing memory...")

    def get_memory_kwargs(data, train_rs):
        # Precompute per-step features and grid lookups for both animals
        # (columns 0-1 = animal a, 2-4 = animal b) so they are not recomputed
        # every training iteration.
        feature_vecs_a = f_corner_vec_func(data[:-1, 0:2])
        feature_vecs_b = f_corner_vec_func(data[:-1, 2:4])
        gpt_idx_a, grid_idx_a = tran.get_gpt_idx_and_grid_idx_for_batch(data[:-1, 0:2])
        gpt_idx_b, grid_idx_b = tran.get_gpt_idx_and_grid_idx_for_batch(data[:-1, 2:4])
        if train_rs:
            # rs is trainable: cache squared distances to the nearby grid points
            nearby_gpts_a = tran.gridpoints[gpt_idx_a]
            dist_sq_a = (data[:-1, None, 0:2] - nearby_gpts_a)**2
            nearby_gpts_b = tran.gridpoints[gpt_idx_b]
            dist_sq_b = (data[:-1, None, 2:4] - nearby_gpts_b)**2
            return dict(feature_vecs_a=feature_vecs_a, feature_vecs_b=feature_vecs_b,
                        gpt_idx_a=gpt_idx_a, gpt_idx_b=gpt_idx_b,
                        grid_idx_a=grid_idx_a, grid_idx_b=grid_idx_b,
                        dist_sq_a=dist_sq_a, dist_sq_b=dist_sq_b)
        else:
            # rs is fixed: the GP coefficients can be cached directly
            coeff_a = tran.get_gp_coefficients(data[:-1, 0:2], 0, gpt_idx_a, grid_idx_a)
            coeff_b = tran.get_gp_coefficients(data[:-1, 2:4], 0, gpt_idx_b, grid_idx_b)
            return dict(feature_vecs_a=feature_vecs_a, feature_vecs_b=feature_vecs_b,
                        gpt_idx_a=gpt_idx_a, gpt_idx_b=gpt_idx_b,
                        grid_idx_a=grid_idx_a, grid_idx_b=grid_idx_b,
                        coeff_a=coeff_a, coeff_b=coeff_b)

    memory_kwargs = get_memory_kwargs(training_data, train_rs)
    valid_data_memory_kwargs = get_memory_kwargs(valid_data, train_rs)

    # sanity-check that the likelihood is computable before training
    log_prob = model.log_likelihood(training_data, **memory_kwargs)

    ##################### training ############################
    if train_model:
        print("start training")
        list_of_losses = []
        if load_opt_dir != "":
            opt = joblib.load(load_opt_dir)
        else:
            opt = None
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters, list_of_lr)):
            # one training stage; the optimizer is carried over between stages
            training_losses, opt, valid_losses = model.fit(training_data, optimizer=opt, method='adam',
                                                           num_iters=num_iters, lr=lr,
                                                           pbar_update_interval=pbar_update_interval,
                                                           valid_data=valid_data,
                                                           valid_data_memory_kwargs=valid_data_memory_kwargs,
                                                           **memory_kwargs)
            list_of_losses.append(training_losses)

            checkpoint_dir = rslt_dir + "/checkpoint_{}".format(i)
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))
            # save model and opt
            joblib.dump(model, checkpoint_dir + "/model")
            joblib.dump(opt, checkpoint_dir + "/optimizer")
            # save losses
            losses = dict(training_loss=training_losses, valid_loss=valid_losses)
            joblib.dump(losses, checkpoint_dir + "/losses")
            plt.figure()
            plt.plot(training_losses)
            plt.title("training loss")
            plt.savefig(checkpoint_dir + "/training_losses.jpg")
            plt.close()
            plt.figure()
            plt.plot(valid_losses)
            plt.title("validation loss")
            plt.savefig(checkpoint_dir + "/valid_losses.jpg")
            plt.close()
            # save rest
            if i in ckpts_not_to_save:
                print("ckpt {}: skip!\n".format(i))
                continue
            # no_grad: result saving only evaluates the model
            with torch.no_grad():
                rslt_saving(rslt_dir=checkpoint_dir, model=model, data=training_data,
                            memory_kwargs=memory_kwargs, list_of_k_steps=list_of_k_steps,
                            sample_T=sample_T, quiver_scale=quiver_scale, valid_data=valid_data,
                            valid_data_memory_kwargs=valid_data_memory_kwargs, device=device)
    else:
        # only save the results
        rslt_saving(rslt_dir=rslt_dir, model=model, data=training_data, memory_kwargs=memory_kwargs,
                    list_of_k_steps=list_of_k_steps, sample_T=sample_T, quiver_scale=quiver_scale,
                    valid_data=valid_data, valid_data_memory_kwargs=valid_data_memory_kwargs, device=device)

    print("Finish running!")
def rslt_saving(rslt_dir, model, data, mouse, sample_T, train_model, losses,
                quiver_scale, x_grids, y_grids):
    """Run inference/sampling on a fitted model and save all results under rslt_dir.

    Saves: a JSON summary (errors, transition matrix, speeds), a joblib dump of
    raw numbers, and a set of diagnostic figures (state sequences, trajectories,
    quiver dynamics, angle/speed/space distributions).

    Parameters
    ----------
    rslt_dir : str
        Output directory (subdirectories samples/, dynamics/, distributions/
        are created on demand).
    model : project HSMM/HMM-like model exposing most_likely_states, sample,
        log_likelihood, transition, observation.
    data : torch tensor of shape (T, D) with D in {2, 4}
        D == 2: single animal; D == 4: two animals (virgin, mother).
    mouse : str
        Label used in plot titles and filenames.
    sample_T : int
        Number of timesteps to sample from the model.
    train_model : bool
        When True, `losses` is saved and plotted.
    losses : sequence of per-iteration losses (only used when train_model).
    quiver_scale : passed through to plot_quiver.
    x_grids, y_grids : grid boundaries; n_x = len(x_grids) - 1 columns,
        n_y = len(y_grids) - 1 rows.
    """
    tran = model.observation.transformation
    _, D = data.shape
    assert D == 2 or D == 4, "D must be either 2 or 4."
    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1
    K = model.K

    #################### inference ###########################
    print("\ninferring most likely states...")
    z = model.most_likely_states(data)

    # Cap prediction to the last 5000 frames to bound runtime.
    print("0 step prediction")
    if data.shape[0] <= 5000:
        data_to_predict = data
    else:
        data_to_predict = data[-5000:]
    x_predict = k_step_prediction(model, z, data_to_predict)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()),
                            axis=0)

    print("5 step prediction")
    x_predict_5 = k_step_prediction(model, z, data_to_predict, k=5)
    # NOTE(review): assumes k_step_prediction(..., k=5) returns predictions
    # aligned with data_to_predict[5:] — confirm against its implementation.
    x_predict_5_err = np.mean(np.abs(x_predict_5 -
                                     data_to_predict[5:].numpy()),
                              axis=0)

    ################### samples #########################
    # One free-running sample, and one conditioned on starting at the arena
    # center (state 0).
    sample_z, sample_x = model.sample(sample_T)

    center_z = torch.tensor([0], dtype=torch.int)
    if D == 4:
        center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64)
    else:
        center_x = torch.tensor([[150, 190]], dtype=torch.float64)
    sample_z_center, sample_x_center = model.sample(sample_T,
                                                    prefix=(center_z,
                                                            center_x))

    ################## dynamics #####################
    # Evaluate the learned dynamics on a 30x30 grid covering the arena for
    # quiver plots.
    XX, YY = np.meshgrid(np.linspace(20, 310, 30), np.linspace(0, 380, 30))
    XY = np.column_stack(
        (np.ravel(XX), np.ravel(YY)))  # shape (900,2) grid values
    if D == 2:
        XY_grids = XY
    else:
        # Two-animal case: place both animals at the same grid position.
        XY_grids = np.concatenate((XY, XY), axis=1)

    XY_next = tran.transform(torch.tensor(XY_grids, dtype=torch.float64))
    dXY = XY_next.detach().numpy() - XY_grids[:, None]

    #################### saving ##############################
    print("begin saving...")

    # save summary
    avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = np.average(np.abs(np.diff(sample_x, axis=0)), axis=0)
    avg_sample_center_speed = np.average(np.abs(
        np.diff(sample_x_center, axis=0)),
                                         axis=0)
    avg_data_speed = np.average(np.abs(np.diff(data.numpy(), axis=0)), axis=0)

    transition_matrix = model.transition.stationary_transition_matrix
    # Detach only when the matrix is still attached to the autograd graph.
    if transition_matrix.requires_grad:
        transition_matrix = transition_matrix.detach().numpy()
    else:
        transition_matrix = transition_matrix.numpy()

    cluster_centers = get_np(tran.mus_loc)

    summary_dict = {
        "init_dist": model.init_dist.detach().numpy(),
        "transition_matrix": transition_matrix,
        "x_predict_err": x_predict_err,
        "x_predict_5_err": x_predict_5_err,
        "mus": cluster_centers,
        "variance": torch.exp(model.observation.log_sigmas).detach().numpy(),
        "log_likes": model.log_likelihood(data).detach().numpy(),
        "avg_transform_speed": avg_transform_speed,
        "avg_data_speed": avg_data_speed,
        "avg_sample_speed": avg_sample_speed,
        "avg_sample_center_speed": avg_sample_center_speed
    }

    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = {
        "z": z,
        "x_predict": x_predict,
        "x_predict_5": x_predict_5,
        "sample_z": sample_z,
        "sample_x": sample_x,
        "sample_z_center": sample_z_center,
        "sample_x_center": sample_x_center
    }

    if train_model:
        saving_dict['losses'] = losses
        plt.figure()
        plt.plot(losses)
        plt.savefig(rslt_dir + "/losses.jpg")
        plt.close()

    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    plot_z(z, K, title="most likely z for the ground truth")
    plt.savefig(rslt_dir + "/z.jpg")
    plt.close()

    if not os.path.exists(rslt_dir + "/samples"):
        os.makedirs(rslt_dir + "/samples")
        print("Making samples directory...")

    plot_z(sample_z, K, title="sample")
    plt.savefig(rslt_dir + "/samples/sample_z_{}.jpg".format(sample_T))
    plt.close()

    plot_z(sample_z_center, K, title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/sample_z_center_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(data,
               title="ground truth_{}".format(mouse),
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/ground_truth.jpg")
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x,
               title="sample_{}".format(mouse),
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x_center,
               title="sample (starting from center)_{}".format(mouse),
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir +
                "/samples/sample_x_center_{}.jpg".format(sample_T))
    plt.close()

    plot_realdata_quiver(data,
                         z,
                         K,
                         x_grids,
                         y_grids,
                         title="ground truth",
                         cluster_centers=cluster_centers)
    plt.savefig(rslt_dir + "/samples/quiver_ground_truth.jpg", dpi=200)
    # Fix: close the figure like every other plot here (was leaked).
    plt.close()

    plot_realdata_quiver(sample_x,
                         sample_z,
                         K,
                         x_grids,
                         y_grids,
                         title="sample",
                         cluster_centers=cluster_centers)
    plt.savefig(rslt_dir + "/samples/quiver_sample_x_{}.jpg".format(sample_T),
                dpi=200)
    plt.close()

    plot_realdata_quiver(sample_x_center,
                         sample_z_center,
                         K,
                         x_grids,
                         y_grids,
                         title="sample (starting from center)",
                         cluster_centers=cluster_centers)
    plt.savefig(
        rslt_dir + "/samples/quiver_sample_x_center_{}.jpg".format(sample_T),
        dpi=200)
    plt.close()

    # plot mus
    plot_cluster_centers(cluster_centers, x_grids, y_grids)
    plt.savefig(rslt_dir + "/samples/cluster_centers.jpg", dpi=200)
    # Fix: close the figure like every other plot here (was leaked).
    plt.close()

    if not os.path.exists(rslt_dir + "/dynamics"):
        os.makedirs(rslt_dir + "/dynamics")
        print("Making dynamics directory...")

    if D == 2:
        plot_quiver(XY_grids,
                    dXY,
                    mouse,
                    K=K,
                    scale=quiver_scale,
                    alpha=0.9,
                    title="quiver ({})".format(mouse),
                    x_grids=x_grids,
                    y_grids=y_grids,
                    grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_{}.jpg".format(mouse),
                    dpi=200)
        plt.close()
    else:
        plot_quiver(XY_grids[:, 0:2],
                    dXY[..., 0:2],
                    'virgin',
                    K=K,
                    scale=quiver_scale,
                    alpha=0.9,
                    title="quiver (virgin)",
                    x_grids=x_grids,
                    y_grids=y_grids,
                    grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_a.jpg", dpi=200)
        plt.close()

        plot_quiver(XY_grids[:, 2:4],
                    dXY[..., 2:4],
                    'mother',
                    K=K,
                    scale=quiver_scale,
                    alpha=0.9,
                    title="quiver (mother)",
                    x_grids=x_grids,
                    y_grids=y_grids,
                    grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_b.jpg", dpi=200)
        plt.close()

    if not os.path.exists(rslt_dir + "/distributions"):
        os.makedirs(rslt_dir + "/distributions")
        print("Making distributions directory...")

    if D == 4:
        # Two-animal case: per-animal direction and speed distributions.
        data_angles_a, data_angles_b = get_all_angles(data, x_grids, y_grids)
        sample_angles_a, sample_angles_b = get_all_angles(
            sample_x, x_grids, y_grids)
        sample_x_center_angles_a, sample_x_center_angles_b = get_all_angles(
            sample_x_center, x_grids, y_grids)

        plot_list_of_angles(
            [data_angles_a, sample_angles_a, sample_x_center_angles_a],
            ['data', 'sample', 'sample_c'], "direction distribution (virgin)",
            n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_a.jpg")
        plt.close()
        plot_list_of_angles(
            [data_angles_b, sample_angles_b, sample_x_center_angles_b],
            ['data', 'sample', 'sample_c'], "direction distribution (mother)",
            n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_b.jpg")
        plt.close()

        data_speed_a, data_speed_b = get_speed(data, x_grids, y_grids)
        sample_speed_a, sample_speed_b = get_speed(sample_x, x_grids, y_grids)
        sample_x_center_speed_a, sample_x_center_speed_b = get_speed(
            sample_x_center, x_grids, y_grids)

        plot_list_of_speed(
            [data_speed_a, sample_speed_a, sample_x_center_speed_a],
            ['data', 'sample', 'sample_c'], "speed distribution (virgin)",
            n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_a.jpg")
        plt.close()
        plot_list_of_speed(
            [data_speed_b, sample_speed_b, sample_x_center_speed_b],
            ['data', 'sample', 'sample_c'], "speed distribution (mother)",
            n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_b.jpg")
        plt.close()
    else:
        data_angles_a = get_all_angles(data, x_grids, y_grids)
        sample_angles_a = get_all_angles(sample_x, x_grids, y_grids)
        sample_x_center_angles_a = get_all_angles(sample_x_center, x_grids,
                                                  y_grids)

        plot_list_of_angles(
            [data_angles_a, sample_angles_a, sample_x_center_angles_a],
            ['data', 'sample', 'sample_c'], "direction distribution (virgin)",
            n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_{}.jpg".format(mouse))
        plt.close()

        data_speed_a = get_speed(data, x_grids, y_grids)
        sample_speed_a = get_speed(sample_x, x_grids, y_grids)
        sample_x_center_speed_a = get_speed(sample_x_center, x_grids, y_grids)

        plot_list_of_speed(
            [data_speed_a, sample_speed_a, sample_x_center_speed_a],
            ['data', 'sample', 'sample_c'], "speed distribution (virgin)",
            n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_{}.jpg".format(mouse))
        plt.close()

    # Spatial occupancy plots; skipped for very short series and capped at
    # 36000 frames for runtime. Best-effort: any failure is reported, not
    # raised (narrowed from a bare `except:` so Ctrl-C still works).
    try:
        if 100 < data.shape[0] <= 36000:
            plot_space_dist(data, x_grids, y_grids)
        elif data.shape[0] > 36000:
            plot_space_dist(data[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_data.jpg")
        plt.close()

        if 100 < sample_x.shape[0] <= 36000:
            plot_space_dist(sample_x, x_grids, y_grids)
        elif sample_x.shape[0] > 36000:
            plot_space_dist(sample_x[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x.jpg")
        plt.close()

        if 100 < sample_x_center.shape[0] <= 36000:
            plot_space_dist(sample_x_center, x_grids, y_grids)
        elif sample_x_center.shape[0] > 36000:
            plot_space_dist(sample_x_center[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x_center.jpg")
        plt.close()
    except Exception:
        print("plot_space_dist unsuccessful")
def fit(self,
        datas,
        inputs=None,
        optimizer=None,
        method='adam',
        num_iters=1000,
        lr=0.001,
        pbar_update_interval=10,
        valid_data=None,
        transition_memory_kwargs=None,
        valid_data_transition_memory_kwargs=None,
        valid_data_memory_kwargs=None,
        **memory_kwargs):
    """Fit the model by gradient descent on `self.loss`.

    Parameters
    ----------
    datas : training observations, forwarded to self.loss.
    inputs : optional exogenous inputs, forwarded to self.loss.
    optimizer : an existing torch.optim.SGD/Adam to resume from, or None to
        create a fresh one according to `method`. When resuming, the learning
        rate of every param group is overwritten with `lr`.
    method : 'adam' or 'sgd'; only used when `optimizer` is None.
    num_iters : number of gradient steps.
    lr : learning rate.
    pbar_update_interval : progress-bar refresh period (iterations).
    valid_data : optional validation set; when given (even empty), validation
        losses are tracked and returned as a third element.
    transition_memory_kwargs, valid_data_transition_memory_kwargs,
    valid_data_memory_kwargs, **memory_kwargs : precomputed caches forwarded
        to self.loss for the training / validation sets respectively.

    Returns
    -------
    (losses, optimizer) or (losses, optimizer, valid_losses) when
    `valid_data` is not None. Losses are numpy scalars.

    Raises
    ------
    ValueError if `method` is not 'adam' or 'sgd'.
    """
    pbar = trange(num_iters, file=sys.stdout)

    if optimizer is None:
        if method == 'adam':
            optimizer = torch.optim.Adam(self.trainable_params, lr=lr)
        elif method == 'sgd':
            optimizer = torch.optim.SGD(self.trainable_params, lr=lr)
        else:
            raise ValueError("Method must be chosen from adam and sgd.")
    else:
        assert isinstance(optimizer, (torch.optim.SGD, torch.optim.Adam)), \
            "Optimizer must be chosen from SGD or Adam"
        # Resuming: force the requested learning rate onto the existing
        # optimizer so each fit stage controls its own lr.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    losses = []
    if valid_data is not None:
        valid_losses = []
        valid_data_memory_kwargs = valid_data_memory_kwargs if valid_data_memory_kwargs else {}

    # Idiom fix: builtin range instead of np.arange (no array allocation,
    # plain ints for the iteration counter).
    for i in range(num_iters):
        optimizer.zero_grad()
        loss = self.loss(datas,
                         inputs,
                         transition_memory_kwargs=transition_memory_kwargs,
                         **memory_kwargs)
        loss.backward()
        optimizer.step()

        loss = get_np(loss)
        losses.append(loss)

        if valid_data is not None:
            if len(valid_data) > 0:
                # Validation loss is evaluated without building a graph.
                with torch.no_grad():
                    valid_losses.append(
                        get_np(
                            self.loss(valid_data,
                                      transition_memory_kwargs=
                                      valid_data_transition_memory_kwargs,
                                      **valid_data_memory_kwargs)))

        if i % pbar_update_interval == 0:
            with nostdout():
                pbar.set_description('iter {} loss {:.2f}'.format(i, loss))
                pbar.update(pbar_update_interval)

    pbar.close()

    if valid_data is not None:
        return losses, optimizer, valid_losses
    return losses, optimizer