def most_likely_states(self,
                           data,
                           input=None,
                           transition_mkwargs=None,
                           **memory_kwargs):
        """
        Viterbi-decode the most likely hidden-state sequence for ``data``.

        :param data: (T, D) observations (array or tensor)
        :param input: optional external input; required when the transition
            is an InputDrivenTransition
        :param transition_mkwargs: extra kwargs forwarded to
            transition.log_transition_matrix
        :param memory_kwargs: extra kwargs forwarded to observation.log_prob
        :return: np.ndarray of most-likely states, or an empty array for
            empty data
        """
        if len(data) == 0:
            return np.array([])
        if isinstance(self.transition,
                      InputDrivenTransition) and input is None:
            raise ValueError("Please provide input.")

        if input is not None:
            input = check_and_convert_to_tensor(input, device=self.device)

        data = check_and_convert_to_tensor(data, device=self.device)
        T = data.shape[0]

        log_pi0 = get_np(self.init_state_distn.log_probs)

        if isinstance(self.transition, StationaryTransition):
            log_Ps = self.transition.log_stationary_transition_matrix
            log_Ps = get_np(log_Ps)  # (K, K)
            log_Ps = log_Ps[None, ]  # (1, K, K), broadcast over time
        else:
            assert isinstance(self.transition,
                              GridTransition), type(self.transition)
            transition_mkwargs = transition_mkwargs if transition_mkwargs else {}
            # BUG FIX: `input[:-1] if input else input` evaluated the truth
            # value of a multi-element tensor, which raises
            # "Boolean value of Tensor is ambiguous"; compare against None.
            if input is not None:
                input = input[:-1]
            log_Ps = self.transition.log_transition_matrix(
                data[:-1], input, **transition_mkwargs)
            log_Ps = get_np(log_Ps)
            assert log_Ps.shape == (T - 1, self.K, self.K), log_Ps.shape

        log_likes = get_np(self.observation.log_prob(data, **memory_kwargs))
        return viterbi(log_pi0, log_Ps, log_likes)
# Example #2
def plot_grid_transition(n_x, n_y, grid_transition):
    """
    Note: this is for single animal case
    plot the grid transition matrices. return a Figure object
    """
    # TODO: fit for n_x = 1 or n_y = 1
    assert n_x > 1 and n_y > 1
    fig, axn = plt.subplots(n_y, n_x, sharex=True, sharey=True, figsize=(10, 10))

    cbar_ax = fig.add_axes([.93, .3, .03, .4])

    # n_x corresponds to the number of columns, and n_y corresponds to the number of rows.
    grid_idx = 0
    for col in range(n_x):
        for row in range(n_y):
            # flip the row index so grid (0, 0) lands bottom-left
            ax = axn[n_y - row - 1][col]
            # draw the shared colorbar exactly once, on the first heatmap
            if col == 0 and row == 0:
                cbar_kwargs = {"cbar_ax": cbar_ax}
            else:
                cbar_kwargs = {"cbar": False}
            sns.heatmap(get_np(grid_transition[grid_idx]), ax=ax, vmin=0,
                        vmax=1, cmap="BuGn", square=True, **cbar_kwargs)
            ax.tick_params(axis='both', which='both', length=0)
            grid_idx += 1

    plt.tight_layout(rect=[0, 0, .9, 1])
 def plot_grid_transition(self, axn, cbar_ax, n_x, n_y, grid_transition):
     """
     Note: this is for single animal case
     plot the grid transition matrices. return a Figure object

     :param axn: 2D array of matplotlib axes (n_y rows, n_x columns)
     :param cbar_ax: axis reserved for the shared colorbar
     :param n_x: number of grid columns
     :param n_y: number of grid rows
     :param grid_transition: sequence of per-grid (K, K) transition matrices
     """
     # n_x corresponds to the number of columns, and n_y corresponds to the number of rows.
     grid_idx = 0
     for i in range(n_x):
         for j in range(n_y):
             # plot_idx = ij_to_plot_idx(i, j, n_x, n_y)
             # plt.subplot(n_x, n_y, plot_idx)
             # flip the row index so grid (0, 0) is drawn bottom-left
             ax = axn[n_y - j - 1][i]
             ax.clear()
             # only the very first heatmap draws the shared colorbar
             if i == 0 and j == 0:
                 sns.heatmap(get_np(grid_transition[grid_idx]),
                             ax=ax,
                             vmin=0,
                             vmax=1,
                             cmap="BuGn",
                             square=True,
                             cbar_ax=cbar_ax)
             else:
                 sns.heatmap(get_np(grid_transition[grid_idx]),
                             ax=ax,
                             vmin=0,
                             vmax=1,
                             cmap="BuGn",
                             square=True,
                             cbar=False)
             ax.tick_params(axis='both', which='both', length=0)
             grid_idx += 1
# Example #4
def get_speed_for_single_animal(data, x_grids, y_grids, device=torch.device('cpu')):
    """Group per-step speeds by the grid cell each step starts in.

    :param data: (T, 2) trajectory, np.ndarray or torch.Tensor
    :param x_grids: grid boundary x-coordinates
    :param y_grids: grid boundary y-coordinates
    :param device: torch device used when converting numpy input to a tensor
    :return: list of length G (number of grid cells); element g is the array
        of speeds of all steps whose start position lies in grid cell g
    """
    # data: (T, 2)
    _, D = data.shape
    assert D == 2

    if isinstance(data, np.ndarray):
        diff = np.diff(data, axis=0)  # (T-1, 2)
        data = torch.tensor(data, dtype=torch.float64, device=device)
        masks_a = get_masks_for_single_animal(data[:-1], x_grids, y_grids)
    elif isinstance(data, torch.Tensor):
        masks_a = get_masks_for_single_animal(data[:-1], x_grids, y_grids)
        diff = np.diff(get_np(data), axis=0)  # (T-1, 2)
    else:
        raise ValueError("Data must be either np.ndarray or torch.Tensor")

    # Euclidean step length per frame. NOTE(review): original comment claimed
    # shape (T-1, 2); the sqrt of summed squares is actually (T-1, ).
    speed_a_all = np.sqrt(diff[:, 0] ** 2 + diff[:, 1] ** 2)  # (T-1, )

    speed_a = []
    G = (len(x_grids) - 1) * (len(y_grids) - 1)
    for g in range(G):
        # select the speeds of steps starting inside grid cell g
        speed_a_g = speed_a_all[get_np(masks_a[g]) == 1]

        speed_a.append(speed_a_g)

    return speed_a
    def plot_quiver(self,
                    axs,
                    K,
                    scale=1,
                    alpha=1,
                    x_grids=None,
                    y_grids=None,
                    grid_alpha=1):
        """Quiver plot of the learned flow field for each hidden state.

        :param axs: a single axis when K == 1, else a sequence of K axes
        :param K: number of hidden states to draw
        :param scale: forwarded to Axes.quiver
        :param alpha: forwarded to Axes.quiver
        :param x_grids: currently unused (grid overlay is commented out)
        :param y_grids: currently unused (grid overlay is commented out)
        :param grid_alpha: currently unused (grid overlay is commented out)
        """
        #TODO: for single animal now
        # quiver
        # evaluation lattice: 30x30 points.
        # NOTE(review): arena extent (20..310, 0..380) is hard-coded here.
        XX, YY = np.meshgrid(np.linspace(20, 310, 30), np.linspace(0, 380, 30))
        XYs = np.column_stack(
            (np.ravel(XX), np.ravel(YY)))  # shape (900,2) grid values
        if isinstance(self.observation, GPObservationSingle):
            XY_next, _ = self.observation.get_mu_and_cov_for_single_animal(
                XYs, mu_only=True)
        else:
            XY_next = self.observation.transformation.transform(
                torch.tensor(XYs, dtype=torch.float64, device=self.device))
        # displacement of each lattice point under each state's dynamics
        dXYs = get_np(XY_next) - XYs[:, None]

        if K == 1:
            # single-state case: axs is one axis, not a sequence
            axs.clear()
            axs.quiver(XYs[:, 0],
                       XYs[:, 1],
                       dXYs[:, 0, 0],
                       dXYs[:, 0, 1],
                       angles='xy',
                       scale_units='xy',
                       scale=scale,
                       alpha=alpha)
            # add_grid(x_grids, y_grids, grid_alpha=grid_alpha)
            if isinstance(self.observation, GPObservationSingle):
                axs.set_title("K={}, rs={}".format(
                    0, get_np(self.observation.rs[0])))
            else:
                axs.set_title("K={}".format(0))
        else:
            for k in range(K):
                axs[k].clear()
                axs[k].quiver(XYs[:, 0],
                              XYs[:, 1],
                              dXYs[:, k, 0],
                              dXYs[:, k, 1],
                              angles='xy',
                              scale_units='xy',
                              scale=scale,
                              alpha=alpha)
                #add_grid(x_grids, y_grids, grid_alpha=grid_alpha)
                if isinstance(self.observation, GPObservationSingle):
                    axs[k].set_title('K={}, rs={}'.format(
                        k, get_np(self.observation.rs[k])))
                else:
                    axs[k].set_title('K={}'.format(k))
# Example #6
def add_grid_to_ax(ax, x_grids, y_grids):
    """Draw the grid-world boundaries onto a specific axis.

    Scatters the four arena corners, then draws dashed grey lines along
    every horizontal and vertical grid boundary.
    """
    if isinstance(x_grids, torch.Tensor):
        x_grids = get_np(x_grids)
    if isinstance(y_grids, torch.Tensor):
        y_grids = get_np(y_grids)

    x_lo, x_hi = x_grids[0], x_grids[-1]
    y_lo, y_hi = y_grids[0], y_grids[-1]

    # the four corners of the arena
    ax.scatter([x_lo, x_lo, x_hi, x_hi], [y_lo, y_hi, y_lo, y_hi])

    # horizontal boundary lines
    for y in y_grids:
        ax.plot([x_lo, x_hi], [y, y], '--', color='grey')

    # vertical boundary lines
    for x in x_grids:
        ax.plot([x, x], [y_lo, y_hi], '--', color='grey')
# Example #7
def add_grid(x_grids, y_grids, grid_alpha=1.0):
    """Draw the grid-world boundaries on the current pyplot figure.

    :param x_grids: x boundary coordinates [x_0, ..., x_m] (or None to skip)
    :param y_grids: y boundary coordinates [y_0, ..., y_n] (or None to skip)
    :param grid_alpha: transparency for corner markers and grid lines
    """
    if x_grids is None or y_grids is None:
        return
    if isinstance(x_grids, torch.Tensor):
        x_grids = get_np(x_grids)
    if isinstance(y_grids, torch.Tensor):
        y_grids = get_np(y_grids)

    # BUG FIX: the corner scatter previously used y_grids[1] instead of
    # y_grids[-1]; the sibling add_grid_to_ax marks the true corners with
    # y_grids[-1], and the grid lines below span y_grids[0]..y_grids[-1].
    plt.scatter([x_grids[0], x_grids[0], x_grids[-1], x_grids[-1]],
                [y_grids[0], y_grids[-1], y_grids[0], y_grids[-1]], alpha=grid_alpha)
    for j in range(len(y_grids)):
        plt.plot([x_grids[0], x_grids[-1]], [y_grids[j], y_grids[j]], '--', color='grey', alpha=grid_alpha)

    for i in range(len(x_grids)):
        plt.plot([x_grids[i], x_grids[i]], [y_grids[0], y_grids[-1]], '--', color='grey', alpha=grid_alpha)
    def sample_z(self, T):
        """Sample a length-T state sequence from the stationary Markov chain.

        :param T: number of time steps
        :return: (T, ) int tensor of hidden states
        """
        # sample the downsampled_t-invariant markov chain only

        # TODO: may want to expand to other cases
        assert isinstance(self.transition, StationaryTransition), \
            "Sampling the makov chain only supports for stationary transition"

        states = torch.empty(T, dtype=torch.int, device=self.device)

        # initial state from the initial state distribution
        init_probs = get_np(self.init_state_distn.probs)
        states[0] = npr.choice(self.K, p=init_probs)

        # each later state is drawn from the transition-matrix row
        # indexed by the previous state
        trans_mat = get_np(self.transition.stationary_transition_matrix)  # (K, K)
        for step in range(1, T):
            states[step] = npr.choice(self.K, p=trans_mat[states[step - 1]])
        return states
# Example #9
def plot_dynamics(weighted_corner_vecs, animal, x_grids, y_grids, K, scale=0.1, percentage=None, title=None,
                  grid_alpha=1):
    """
    This is for the illustration of the dynamics of the discrete grid model. Probably want to make the case of K>8

    :param weighted_corner_vecs: per-grid weighted corner dynamics vectors,
        indexed as [grid, k, corner, xy] with 4 corners per grid
        -- assumed from the indexing below; TODO confirm against the caller
    :param animal: name appended to each subplot title
    :param K: number of hidden states. NOTE(review): K > 8 falls through
        both branches and plots nothing.
    :param percentage: optional per-grid occupancy annotation
    """
    if isinstance(x_grids, torch.Tensor):
        x_grids = get_np(x_grids)
    if isinstance(y_grids, torch.Tensor):
        y_grids = get_np(y_grids)

    # net dynamics vector: sum of the four corner contributions (axis 2)
    result_corner_vecs = np.sum(weighted_corner_vecs, axis=2)
    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1
    # midpoint of each grid cell, in the same flat ordering as the vectors
    grid_centers = np.array(
        [[1 / 2 * (x_grids[i] + x_grids[i + 1]), 1 / 2 * (y_grids[j] + y_grids[j + 1])] for i in range(n_x) for j in
         range(n_y)])
    # four corner vectors per grid cell
    Df = 4

    def plot_dynamics_k(k):
        # draw one subplot: grid, occupancy annotation, per-corner quivers
        # in black, and the summed vector in red
        add_grid(x_grids, y_grids, grid_alpha=grid_alpha)

        add_percentage(k, percentage=percentage, grid_centers=grid_centers)

        for df in range(Df):
            plt.quiver(grid_centers[:, 0], grid_centers[:, 1], weighted_corner_vecs[:, k, df, 0],
                       weighted_corner_vecs[:, k, df, 1],
                       units='xy', scale=scale, width=2, alpha=0.5)

        plt.quiver(grid_centers[:, 0], grid_centers[:, 1], result_corner_vecs[:, k, 0], result_corner_vecs[:, k, 1],
                   units='xy', scale=scale, width=2, color='red', alpha=0.5)
        plt.title("K={}, ".format(k) + animal, fontsize=20)

    if K <= 4:
        # single row of subplots
        plt.figure(figsize=(20, 5))
        if title is not None:
            plt.suptitle(title)
        for k in range(K):
            plt.subplot(1, K, k+1)
            plot_dynamics_k(k)
    elif 4 < K <= 8:
        # two rows of subplots
        plt.figure(figsize=(20, 10))
        if title is not None:
            plt.suptitle(title)
        for k in range(K):
            plt.subplot(2, int(K/2)+1, k+1)
            plot_dynamics_k(k)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
# Example #10
def get_all_angles(data, x_grids, y_grids, device=torch.device('cpu')):
    """Per-grid heading angles computed from consecutive displacements.

    Forms the step vectors data[t+1] - data[t] and delegates to
    get_all_angles_from_quiver, keyed by each step's start position.
    """
    displacements = data[1:] - data[:-1]
    if isinstance(displacements, torch.Tensor):
        displacements = get_np(displacements)

    return get_all_angles_from_quiver(data[:-1], displacements, x_grids,
                                      y_grids, device=device)
# Example #11
    def get_gpt_idx_and_grid_for_single(self, point):
        """
        Find the grid cell containing ``point`` and the four grid points
        (corners) of that cell.

        :param point: (2,)
        :return: gpt_idx a list of length 4 (Q11, Q12, Q21, Q22), and the
            flat grid-cell index
        """
        assert point.shape == (2, )

        n_gp_y = len(self.y_grids)
        grid_idx = 0
        for i in range(len(self.x_grids) - 1):
            for j in range(len(self.y_grids) - 1):
                inside_x = self.x_grids[i] <= point[0] <= self.x_grids[i + 1]
                inside_y = self.y_grids[j] <= point[1] <= self.y_grids[j + 1]
                if inside_x and inside_y:
                    # corner grid-point indices, column-major over y
                    base = i * n_gp_y + j
                    gpt_idx = [base,                # Q11
                               base + 1,            # Q12
                               base + n_gp_y,       # Q21
                               base + n_gp_y + 1]   # Q22
                    return gpt_idx, grid_idx
                grid_idx += 1
        raise ValueError("value {} out of the grid world.".format(
            get_np(point)))
    def get_gridpoints_idx_for_single_point(self, point):
        """
        Q12 -- Q22
        |
        Q11 -- Q21
        :param point: (2,)
        :return: idx (4, ) idx for Q11, Q12, Q21, Q22)
        """
        assert point.shape == (2, )

        for i in range(self.n_x):
            for j in range(self.n_y):
                within_x = self.x_grids[i] <= point[0] <= self.x_grids[i + 1]
                within_y = self.y_grids[j] <= point[1] <= self.y_grids[j + 1]
                if within_x and within_y:
                    # flat indices of the cell's four corner grid points
                    q11 = i * self.ly + j
                    q21 = (i + 1) * self.ly + j
                    return [q11, q11 + 1, q21, q21 + 1]
        raise ValueError("value {} out of the grid world.".format(
            get_np(point)))
    def get_gridpoints_idx_for_single(self, point):
        """
        Indicator matrix over grid points for the cell containing ``point``.

        :param point: (2,)
        :return: idx (n_gps, 4); column c has a single 1 at the grid point
            playing the role of corner c (Q11, Q12, Q21, Q22)
        """
        assert point.shape == (2, )

        idx = torch.zeros((self.GP, 4), dtype=torch.float64, device=self.device)

        n_gp_y = len(self.y_grids)
        for i in range(len(self.x_grids) - 1):
            for j in range(len(self.y_grids) - 1):
                inside_x = self.x_grids[i] <= point[0] <= self.x_grids[i + 1]
                inside_y = self.y_grids[j] <= point[1] <= self.y_grids[j + 1]
                if inside_x and inside_y:
                    base = i * n_gp_y + j
                    idx[base, 0] = 1  # Q11
                    idx[base + 1, 1] = 1  # Q12
                    idx[base + n_gp_y, 2] = 1  # Q21
                    idx[base + n_gp_y + 1, 3] = 1  # Q22
                    return idx
        raise ValueError("value {} out of the grid world.".format(get_np(point)))
# Example #14
def plot_transition(transition):
    """
    plot the heatmap for one transition matrix
    :param transition: (K, K) transition probability matrix (tensor or array);
        cells are colored on a fixed [0, 1] scale
    :return: None; draws onto the current matplotlib axes
    """
    sns.heatmap(get_np(transition), vmin=0, vmax=1, cmap="BuGn", square=True)
# Example #15
def plot_space_dist(data, x_grids, y_grids, grid_alpha=1):
    """KDE plot of spatial occupancy with the grid world overlaid.

    :param data: (T, 2) single-animal or (T, 4) two-animal trajectory
    :param x_grids: grid boundary x-coordinates
    :param y_grids: grid boundary y-coordinates
    :param grid_alpha: transparency of the grid overlay
    """
    # TODO: there are some
    T, D = data.shape
    assert D == 2 or D == 4
    if isinstance(data, torch.Tensor):
        data = get_np(data)

    # contour density scales with the amount of data
    n_levels = int(T / 36)

    if D == 4:
        # two animals: side-by-side subplots (virgin = cols 0-1, mother = 2-3)
        plt.figure(figsize=(15, 7))

        plt.subplot(1, 2, 1)
        sns.kdeplot(data[:, 0], data[:, 1], n_levels=n_levels)
        add_grid(x_grids, y_grids, grid_alpha=grid_alpha)
        plt.title("virgin")

        plt.subplot(1, 2, 2)
        sns.kdeplot(data[:, 2], data[:, 3], n_levels=n_levels)
        add_grid(x_grids, y_grids, grid_alpha=grid_alpha)
        plt.title("mother")

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    else:
        plt.figure(figsize=(8, 7))

        sns.kdeplot(data[:, 0], data[:, 1], n_levels=n_levels)
        # BUG FIX: this branch previously dropped grid_alpha, making the
        # single-animal overlay inconsistent with the two-animal branch.
        add_grid(x_grids, y_grids, grid_alpha=grid_alpha)
# Example #16
def get_z_percentage_by_grid(masks_a, z, K, G):
    """For each grid cell, the fraction of time spent in each hidden state.

    :param masks_a: G per-grid membership masks over time steps
    :param z: hidden-state sequence; z[:-1] aligns with the masks
    :param K: number of hidden states
    :param G: number of grid cells
    :return: (G, K) row-normalized occupancy percentages
    """
    # shift states by 1 so state 0 is distinguishable from masked-out zeros
    shifted_z = z[:-1] + 1
    counts = []
    for g in range(G):
        tagged = shifted_z * get_np(masks_a[g])
        # count of time steps in grid g with (shifted) hidden state k
        counts.append([sum(tagged == k) for k in range(1, K + 1)])
    grid_counts = np.array(counts)  # (G, K)

    # row-normalize; the epsilon guards grid cells with no samples
    return grid_counts / (grid_counts.sum(axis=1)[:, None] + 1e-6)
    def __init__(self, K, D, x_grids, y_grids, Df, feature_vec_func, tran=None, acc_factor=2, lags=1,
                 use_log_prior=False, no_boundary_prior=False, add_log_diagonal_prior=False, log_prior_sigma_sq=-np.log(1e3),
                 device=torch.device('cpu'), version=1):
        """Set up the linear-grid transformation.

        :param K: number of hidden states
        :param D: full observation dimension; d = D / 2 coordinates
        :param x_grids: grid boundaries [x_0, ..., x_m]
        :param y_grids: grid boundaries [y_0, ..., y_n]
        :param Df: number of dynamics features per grid point
        :param feature_vec_func: callable producing feature vectors
        :param tran: optional existing LinearGridTransformation; if given,
            its hyper-parameters and weights are copied (detached via get_np)
            into this instance and the corresponding keyword args are ignored
        :param acc_factor: acceleration factor for the dynamics
        :param lags: must be 1 (only one-step dynamics are supported)
        :param log_prior_sigma_sq: log prior variance for the weights
        :param version: implementation variant -- semantics not visible here;
            TODO confirm the meaning of the values
        """
        assert lags == 1, "lags should be 1 for lineargrid with_noise."
        super(LinearGridTransformation, self).__init__(K, D)

        self.version = version

        self.d = int(self.D / 2)
        self.device = device

        self.x_grids = check_and_convert_to_tensor(x_grids, dtype=torch.float64, device=self.device)  # [x_0, x_1, ..., x_m]
        self.y_grids = check_and_convert_to_tensor(y_grids, dtype=torch.float64, device=self.device)  # a list [y_0, y_1, ..., y_n]
        self.n_x = len(x_grids) - 1
        self.n_y = len(y_grids) - 1
        # shape: (d, n_gps)
        self.gridpoints = torch.tensor([(x_grid, y_grid) for x_grid in self.x_grids for y_grid in self.y_grids],
                                       device=device)
        self.gridpoints = torch.transpose(self.gridpoints, 0, 1)
        # number of basis grid points
        self.GP = self.gridpoints.shape[1]

        self.Df = Df
        self.feature_vec_func = feature_vec_func

        if tran is not None:
            # copy hyper-parameters and weights from an existing model;
            # get_np + torch.tensor detaches them from the source graph
            assert isinstance(tran, LinearGridTransformation)
            self.use_log_prior = tran.use_log_prior
            self.add_log_diagonal_prior = tran.add_log_diagonal_prior
            self.no_boundary_prior = tran.no_boundary_prior
            self.log_prior_sigma_sq = torch.tensor(get_np(tran.log_prior_sigma_sq), dtype=torch.float64,
                                                   device=self.device)
            self.acc_factor = tran.acc_factor
            self.Ws = torch.tensor(get_np(tran.Ws), dtype=torch.float64, requires_grad=True, device=self.device)

        else:
            self.use_log_prior = use_log_prior
            self.add_log_diagonal_prior = add_log_diagonal_prior
            self.no_boundary_prior = no_boundary_prior
            self.log_prior_sigma_sq = torch.tensor(log_prior_sigma_sq, dtype=torch.float64, device=device)

            self.acc_factor = acc_factor
            # trainable weights: (K, 2, GP, Df), randomly initialized
            self.Ws = torch.rand(self.K, 2, self.GP, self.Df, dtype=torch.float64, requires_grad=True,
                                 device=self.device)
    def sample(self, sample_shape=torch.Size()):
        """Draw from the truncated normal via scipy.stats.truncnorm.

        :param sample_shape: currently consider one
        :return: (D, )
        """

        # first, transform to standard normal
        scale = np.exp(get_np(self.log_sigmas))  # (D, )
        loc = get_np(self.mus)  # (D, )
        # truncnorm expects bounds standardized by (bound - loc) / scale
        bb = (get_np(self.bounds[..., 1]) - loc) / scale  # (D, )
        aa = (get_np(self.bounds[..., 0]) - loc) / scale  # (D, )
        if sample_shape == ():
            D = bb.shape[0]
            samples = truncnorm.rvs(a=aa, b=bb, loc=loc, scale=scale, size=(D))

            # some adhoc way to fix the infinity in sample
            # NOTE(review): -inf is replaced with the *upper* bound
            # (bounds[..., 1]) and +inf with the *lower* bound -- this looks
            # swapped; confirm the intent before relying on it.
            samples[samples == -np.inf] = get_np(
                self.bounds[..., 1])[samples == -np.inf]
            samples[samples == np.inf] = get_np(
                self.bounds[..., 0])[samples == np.inf]
        else:
            # note: the infinity patch above is not applied on this path
            samples = truncnorm.rvs(a=aa,
                                    b=bb,
                                    loc=loc,
                                    scale=scale,
                                    size=sample_shape)

        samples = torch.tensor(samples, dtype=torch.float64)
        return samples
# Example #19
    def sample_condition_on_zs(self,
                               zs,
                               x0=None,
                               transformation=False,
                               return_np=True,
                               **kwargs):
        """
        Given a z sequence, generate samples condition on this sequence.
        :param zs: (T, )
        :param x0: shape (D,)
        :param transformation: forwarded to sample_x as ``with_noise``;
            True returns the deterministic transformed value at each step
        :param return_np: return np.ndarray or torch.tensor
        :return: generated samples (T, D)
        """

        zs = check_and_convert_to_tensor(zs,
                                         dtype=torch.int,
                                         device=self.device)
        T = zs.shape[0]

        assert T > 0
        dtype = torch.float64
        xs = torch.zeros((T, self.D), dtype=dtype)
        if T == 1:
            if x0 is not None:
                print("Nothing to sample")
                return
            else:
                return self.observation.sample_x(zs[0],
                                                 with_noise=transformation)

        if x0 is None:
            x0 = self.observation.sample_x(zs[0],
                                           with_noise=transformation,
                                           return_np=False)
        else:
            x0 = check_and_convert_to_tensor(x0,
                                             dtype=dtype,
                                             device=self.device)
            assert x0.shape == (self.D, )

        xs[0] = x0
        for t in np.arange(1, T):
            # BUG FIX: the keyword was misspelled `xihst`, so the history
            # was silently swallowed by sample_x's **memory_kwargs and
            # every step fell back to the initial-state distribution.
            x_t = self.observation.sample_x(zs[t],
                                            xhist=xs[:t],
                                            with_noise=transformation,
                                            return_np=False,
                                            **kwargs)
            xs[t] = x_t

        if return_np:
            return get_np(xs)
        return xs
# Example #20
 def plot_transition(self, ax, cbar_ax, transition):
     """
     plot the heatmap for one transition matrix
     :param ax: axis to draw on; cleared first so the plot can be refreshed
     :param cbar_ax: axis that receives the colorbar
     :param transition: (K, K)
     :return:
     """
     ax.clear()
     sns.heatmap(get_np(transition),
                 ax=ax,
                 vmin=0,
                 vmax=1,
                 cmap="BuGn",
                 square=True,
                 cbar_ax=cbar_ax)
# Example #21
def get_all_angles_from_quiver(XY, dXY, x_grids, y_grids, device=torch.device('cpu')):
    """Group displacement angles by the grid cell of their start position.

    :param XY: (T, 2) or (T, 4) start positions; converted to tensor if numpy
    :param dXY: displacements, same leading length as XY
    :return: list of per-grid angle arrays for one animal (D == 2), or a
        pair of such lists for two animals (D == 4)
    """
    # XY and dXY should have the same shape
    if isinstance(XY, np.ndarray):
        XY = torch.tensor(XY, dtype=torch.float64, device=device)

    _, D = XY.shape
    n_grids = (len(x_grids) - 1) * (len(y_grids) - 1)

    if D == 4:
        # two-animal data: columns 0-1 belong to animal a, 2-3 to animal b
        masks_a, masks_b = get_masks_for_two_animals(XY, x_grids, y_grids)
        masks_a = get_np(masks_a)
        masks_b = get_np(masks_b)

        angles_a = []
        angles_b = []
        for g in range(n_grids):
            in_grid_a = dXY[masks_a[g] == 1][:, 0:2]
            in_grid_b = dXY[masks_b[g] == 1][:, 2:4]

            angles_a.append(get_angles_single_from_quiver(in_grid_a))
            angles_b.append(get_angles_single_from_quiver(in_grid_b))

        return angles_a, angles_b

    if D == 2:
        masks_a = get_np(get_masks_for_single_animal(XY, x_grids, y_grids))

        angles_a = []
        for g in range(n_grids):
            in_grid = dXY[masks_a[g] == 1][:, 0:2]
            angles_a.append(get_angles_single_from_quiver(in_grid))

        return angles_a

    raise ValueError("Invalid data shape")
# Example #22
def plot_mouse(data, alpha=.8, title=None, xlim=None, ylim=None, mouse='both'):
    """Plot mouse trajectories on the current pyplot figure.

    :param data: (T, 4) two-animal data (virgin cols 0-1, mother cols 2-3)
        or (T, 2) single-animal data labelled with ``mouse``
    """
    if isinstance(data, torch.Tensor):
        data = get_np(data)
    if title is not None:
        plt.title(title)

    _, D = data.shape
    assert D == 4 or D == 2

    if D == 4:
        tracks = [(0, 1, 'virgin'), (2, 3, 'mother')]
    else:
        tracks = [(0, 1, mouse)]
    for x_col, y_col, track_label in tracks:
        plt.plot(data[:, x_col], data[:, y_col], label=track_label, alpha=alpha)

    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
# Example #23
def plot_2d_time_plot_condition_on_all_zs(data, z, K, title, time_start=None, time_end=None, size=0.5):
    """One stacked subplot per hidden state over a time window.

    :param data: (T, D) observations (tensor or array)
    :param z: hidden-state sequence aligned with data
    :param K: number of hidden states (one subplot each)
    :param title: figure suptitle; skipped if None
    :param time_start: window start; defaults to 0
    :param time_end: window end; defaults to T
    :param size: marker size forwarded to the per-state plot helper
    """
    data = get_np(data)

    T, _ = data.shape
    # BUG FIX: compare against None instead of truthiness so an explicit
    # time_end=0 (or other falsy value) is respected rather than silently
    # replaced by the default.
    time_start = time_start if time_start is not None else 0
    time_end = time_end if time_end is not None else T

    plt.figure(figsize=(30, 2 * K))
    if title is not None:
        plt.suptitle(title)

    for k in range(K):
        plt.subplot(K, 1, k + 1)
        plot_2d_time_plot_condition_on_z(data, z, k, time_start, time_end, size)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    def sample_x(self,
                 z,
                 xhist=None,
                 with_noise=False,
                 return_np=True,
                 **memory_kwargs):
        """

        :param z: an integer
        :param xhist: (T_pre, D)
        :param with_noise: return transformed value as sample value, instead of sampling
            (NOTE(review): despite the name, True means the deterministic
            transformed mean is returned with *no* noise added)
        :param return_np: boolean, whether return np.ndarray or torch.tensor
        :return: one sample (D, )
        """
        if xhist is None or xhist.shape[0] == 0:

            # no history available: use the initial-state distribution for z
            mu = self.mus_init[z]  # (D,)
            log_sigma = self.log_sigmas_init[z]
        else:
            # sample from the autoregressive distribution

            T_pre = xhist.shape[0]
            if T_pre < self.lags:
                mu = self.transformation.transform_condition_on_z(
                    z, xhist, **memory_kwargs)  # (D, )
            else:
                # only the most recent `lags` observations are used
                mu = self.transformation.transform_condition_on_z(
                    z, xhist[-self.lags:], **memory_kwargs)  # (D, )
            assert mu.shape == (self.D, )

            log_sigma = self.log_sigmas[z]

        if with_noise:
            samples = mu
        else:
            dist = TruncatedNormal(mus=mu,
                                   log_sigmas=log_sigma,
                                   bounds=self.bounds)
            samples = dist.sample()

        # clamp every dimension back inside its bounds
        for d in range(self.D):
            samples[d] = clip(samples[d], self.bounds[d])

        if return_np:
            return get_np(samples)
        return samples
# Example #25
def plot_data_condition_on_all_zs(data, z, K, size=2, alpha=0.3):
    """Scatter the spatial occupancy separately for each hidden state.

    :param data: (T, D) observations (tensor or array)
    :param z: hidden-state sequence aligned with data
    :param K: number of hidden states (one subplot each, 5 per row)
    :param size: marker size forwarded to plot_data_condition_on_zk
    :param alpha: marker transparency forwarded to plot_data_condition_on_zk
    """
    data = get_np(data)

    n_col = 5
    # ceiling division: enough rows to hold K subplots at n_col per row
    n_row = (K + n_col - 1) // n_col

    plt.figure(figsize=(20, 4 * n_row))

    # the title was assigned a literal just above its None-check, so the
    # old `if title is not None` guard was dead code -- removed
    plt.suptitle("spatial occupation under different hidden states")

    for k in range(K):
        plt.subplot(n_row, n_col, k + 1)
        plot_data_condition_on_zk(data, z, k, size=size, alpha=alpha)
        plt.title('K={} '.format(k))

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    def get_grid_idx_for_single(self, point):
        """
        Return the flat index of the grid cell containing ``point``.

        :param point: (2,)
        :return: grid idx: a scalar
        """
        assert point.shape == (2,), point.shape

        grid_idx = 0
        for i in range(len(self.x_grids) - 1):
            for j in range(len(self.y_grids) - 1):
                inside_x = self.x_grids[i] <= point[0] <= self.x_grids[i + 1]
                inside_y = self.y_grids[j] <= point[1] <= self.y_grids[j + 1]
                if inside_x and inside_y:
                    return grid_idx
                # cells are counted column by column (j fastest)
                grid_idx += 1
        raise ValueError("value {} out of the grid world.".format(get_np(point)))
def k_step_prediction_for_lstm_based_model(model,
                                           model_z,
                                           data,
                                           k=0,
                                           feature_vecs=None):
    """k-step-ahead prediction for an LSTM-based model.

    :param model: model exposing .observation.sample_x and .sample
    :param model_z: (T, ) inferred hidden-state sequence
    :param data: (T, D) observations
    :param k: 0 = one-step reconstruction conditioned on the true history;
        k > 0 = sample k steps ahead from each starting time point
    :param feature_vecs: pair (feature_vecs_a, feature_vecs_b) of
        precomputed per-frame features; required for the k == 0 path
    :return: (T - k, D) np.ndarray of predictions
    """
    data = check_and_convert_to_tensor(data)
    T, D = data.shape

    # LSTM hidden/cell state carried across time steps; filled in by the
    # model as a side effect of sampling -- TODO confirm against model.sample
    lstm_states = {}

    x_predict_arr = []
    if k == 0:
        if feature_vecs is None:
            print("Did not provide memory information")
            return k_step_prediction(model, model_z, data)
        else:
            feature_vecs_a, feature_vecs_b = feature_vecs

            # t = 0: data[:0] is empty, so there is no history to condition on
            x_predict = model.observation.sample_x(model_z[0],
                                                   data[:0],
                                                   return_np=True)
            x_predict_arr.append(x_predict)
            for t in range(1, data.shape[0]):
                # features of the previous frame, kept 2-D via slicing
                feature_vec_t = (feature_vecs_a[t - 1:t],
                                 feature_vecs_b[t - 1:t])

                x_predict = model.observation.sample_x(
                    model_z[t],
                    data[:t],
                    return_np=True,
                    with_noise=True,
                    feature_vec=feature_vec_t,
                    lstm_states=lstm_states)
                x_predict_arr.append(x_predict)
    else:
        assert k > 0
        # neglects t = 0 since there is no history

        if T <= k:
            raise ValueError("Please input k such that k < {}.".format(T))

        for t in range(1, T - k + 1):
            # sample k steps forward
            # first step use real value
            z, x = model.sample(1,
                                prefix=(model_z[t - 1:t], data[t - 1:t]),
                                return_np=False,
                                with_noise=True,
                                lstm_states=lstm_states)
            # last k-1 steps use sampled value
            # NOTE(review): `k >= 1` is always true in this branch (k > 0
            # asserted above), so this guard is effectively unconditional
            if k >= 1:
                sampled_lstm_states = dict(h_t=lstm_states["h_t"],
                                           c_t=lstm_states["c_t"])
                for i in range(k - 1):
                    z, x = model.sample(1,
                                        prefix=(z, x),
                                        return_np=False,
                                        with_noise=True,
                                        lstm_states=sampled_lstm_states)
            assert x.shape == (1, D)
            x_predict_arr.append(get_np(x[0]))

    x_predict_arr = np.array(x_predict_arr)
    assert x_predict_arr.shape == (T - k, D)
    return x_predict_arr
# Example #28
def main(job_name, cuda_num, downsample_n, filter_traj, gp_version, load_model,
         load_model_dir, load_opt_dir, transition, sticky_alpha, sticky_kappa,
         acc_factor, k, x_grids, y_grids, n_x, n_y, rs_factor, rs, train_rs,
         train_model, pbar_update_interval, video_clips, held_out_proportion,
         torch_seed, np_seed, list_of_num_iters, ckpts_not_to_save, list_of_lr,
         list_of_k_steps, sample_t, quiver_scale):
    """Entry point: train and/or evaluate a GP-grid HMM on trajectory data.

    Parses CLI-style string arguments, loads an existing model (or builds a
    new ``HMM`` with a ``GPGridTransformation`` observation), optionally runs
    one training stage per entry of ``list_of_num_iters``/``list_of_lr``
    (checkpointing after each), and saves parameters, losses and result plots
    under ``rslts/gpgrid/<job_name><datetime>``.

    Raises:
        ValueError: if ``job_name`` is missing, ``list_of_num_iters`` and
            ``list_of_lr`` have different lengths, or any learning rate > 1.
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")

    # Prefer the requested CUDA device; fall back to CPU when unavailable.
    cuda_num = int(cuda_num)
    device = torch.device(
        "cuda:{}".format(cuda_num) if torch.cuda.is_available() else "cpu")
    print("Using device {} \n\n".format(device))

    K = k  # number of hidden states
    sample_T = sample_t  # length of free-running samples drawn for plots
    # rs_factor arrives as a comma-separated pair; "0,0" means "disabled".
    rs_factor = np.array([float(x) for x in rs_factor.split(",")])
    if rs_factor[0] == 0 and rs_factor[1] == 0:
        rs_factor = None
    rs = float(rs)
    # video_clips is "start,end"; each unit spans 36000 raw frames (below).
    video_clip_start, video_clip_end = [
        float(x) for x in video_clips.split(",")
    ]
    # One training stage per (num_iters, lr) pair.
    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    list_of_k_steps = [int(x) for x in list_of_k_steps.split(",")]
    assert len(list_of_num_iters) == len(
        list_of_lr
    ), "Length of list_of_num_iters must match length of list_of_lr."
    for lr in list_of_lr:
        if lr > 1:
            raise ValueError("Learning rate should not be larger than 1!")

    # Checkpoint indices for which the (slow) result-plot saving is skipped.
    ckpts_not_to_save = [int(x) for x in ckpts_not_to_save.split(',')
                         ] if ckpts_not_to_save else []

    repo = git.Repo(
        '.', search_parent_directories=True)  # SocialBehaviorectories=True)
    repo_dir = repo.working_tree_dir  # SocialBehavior

    # Seed both frameworks for reproducibility.
    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    ########################## data ########################
    data_dir = repo_dir + '/SocialBehaviorptc/data/trajs_all'
    trajs = joblib.load(data_dir)

    # 36000 frames per clip unit -- presumably one hour at the raw frame
    # rate; TODO confirm against the data pipeline.
    traj = trajs[int(36000 * video_clip_start):int(36000 * video_clip_end)]
    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)

    data = torch.tensor(traj, dtype=torch.float64, device=device)
    assert 0 <= held_out_proportion <= 0.4, \
        "held_out-portion should be between 0 and 0.4 (inclusive), but is {}".format(held_out_proportion)
    T = data.shape[0]
    # Chronological split: earlier frames train, final frames validate.
    breakpoint = int(T * (1 - held_out_proportion))
    training_data = data[:breakpoint]
    valid_data = data[breakpoint:]

    ######################### model ####################

    # model
    D = 4   # observation dim: (x, y) for each of two animals
    M = 0   # input (covariate) dim
    Df = 4  # feature dim of the corner feature vectors

    if load_model:
        print("Loading the model from ", load_model_dir)
        model = joblib.load(load_model_dir)
        tran = model.observation.transformation

        # Derive hyperparameters from the loaded model so CLI values cannot
        # disagree with it.
        K = model.K

        n_x = len(tran.x_grids) - 1
        n_y = len(tran.y_grids) - 1

        acc_factor = tran.acc_factor

    else:
        print("Creating the model...")
        bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX],
                           [ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])

        # grids: either split the arena evenly into n_x/n_y cells, or parse
        # explicit grid-line coordinates from the CLI string.
        if x_grids is None:
            x_grid_gap = (ARENA_XMAX - ARENA_XMIN) / n_x
            x_grids = np.array(
                [ARENA_XMIN + i * x_grid_gap for i in range(n_x + 1)])
        else:
            x_grids = np.array([float(x) for x in x_grids.split(",")])
            n_x = len(x_grids) - 1

        if y_grids is None:
            y_grid_gap = (ARENA_YMAX - ARENA_YMIN) / n_y
            y_grids = np.array(
                [ARENA_YMIN + i * y_grid_gap for i in range(n_y + 1)])
        else:
            y_grids = np.array([float(x) for x in y_grids.split(",")])
            n_y = len(y_grids) - 1

        if acc_factor is None:
            acc_factor = downsample_n * 10

        # NOTE(review): the parsed ``rs`` value above is recorded in
        # exp_params but ``rs=None`` is passed here -- confirm intentional.
        tran = GPGridTransformation(K=K,
                                    D=D,
                                    x_grids=x_grids,
                                    y_grids=y_grids,
                                    Df=Df,
                                    feature_vec_func=f_corner_vec_func,
                                    acc_factor=acc_factor,
                                    rs_factor=rs_factor,
                                    rs=None,
                                    train_rs=train_rs,
                                    device=device,
                                    version=gp_version)
        obs = ARTruncatedNormalObservation(K=K,
                                           D=D,
                                           M=M,
                                           lags=1,
                                           bounds=bounds,
                                           transformation=tran,
                                           device=device)

        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        else:
            transition_kwargs = None
        model = HMM(K=K,
                    D=D,
                    M=M,
                    transition=transition,
                    observation=obs,
                    transition_kwargs=transition_kwargs,
                    device=device)
        # Initialize every state's mean at the first training frame.
        model.observation.mus_init = training_data[0] * torch.ones(
            K, D, dtype=torch.float64, device=device)

    # save experiment params
    exp_params = {
        "job_name": job_name,
        'downsample_n': downsample_n,
        "filter_traj": filter_traj,
        "load_model": load_model,
        "gp_version": gp_version,
        "load_model_dir": load_model_dir,
        "load_opt_dir": load_opt_dir,
        "transition": transition,
        "sticky_alpha": sticky_alpha,
        "sticky_kappa": sticky_kappa,
        "acc_factor": acc_factor,
        "K": K,
        "x_grids": x_grids,
        "y_grids": y_grids,
        "n_x": n_x,
        "n_y": n_y,
        "rs_factor": get_np(tran.rs_factor),
        "rs": rs,
        "train_rs": train_rs,
        "train_model": train_model,
        "pbar_update_interval": pbar_update_interval,
        "video_clip_start": video_clip_start,
        "video_clip_end": video_clip_end,
        "held_out_proportion": held_out_proportion,
        "torch_seed": torch_seed,
        "np_seed": np_seed,
        "list_of_num_iters": list_of_num_iters,
        "list_of_lr": list_of_lr,
        "list_of_k_steps": list_of_k_steps,
        "sample_T": sample_T,
        "quiver_scale": quiver_scale
    }

    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/gpgrid/" + job_name)
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving to rlst_dir: ", rslt_dir)
    with open(rslt_dir + "/exp_params.json", "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    # compute memory
    print("Computing memory...")

    def get_memory_kwargs(data, train_rs):
        """Precompute per-timestep quantities (corner features, nearest grid
        point / grid cell indices, and either squared distances to the grid
        points or fixed GP coefficients) that are reused every iteration.

        All quantities are computed on ``data[:-1]`` because the last frame
        has no successor to predict.
        """
        feature_vecs_a = f_corner_vec_func(data[:-1, 0:2])
        feature_vecs_b = f_corner_vec_func(data[:-1, 2:4])

        gpt_idx_a, grid_idx_a = tran.get_gpt_idx_and_grid_idx_for_batch(
            data[:-1, 0:2])
        gpt_idx_b, grid_idx_b = tran.get_gpt_idx_and_grid_idx_for_batch(
            data[:-1, 2:4])

        if train_rs:
            # rs is trainable: keep raw squared distances so coefficients
            # can be recomputed from the current rs each iteration.
            nearby_gpts_a = tran.gridpoints[gpt_idx_a]
            dist_sq_a = (data[:-1, None, 0:2] - nearby_gpts_a)**2

            nearby_gpts_b = tran.gridpoints[gpt_idx_b]
            dist_sq_b = (data[:-1, None, 2:4] - nearby_gpts_b)**2

            return dict(feature_vecs_a=feature_vecs_a,
                        feature_vecs_b=feature_vecs_b,
                        gpt_idx_a=gpt_idx_a,
                        gpt_idx_b=gpt_idx_b,
                        grid_idx_a=grid_idx_a,
                        grid_idx_b=grid_idx_b,
                        dist_sq_a=dist_sq_a,
                        dist_sq_b=dist_sq_b)

        else:
            # rs is fixed: GP coefficients are constant, so cache them once.
            coeff_a = tran.get_gp_coefficients(data[:-1, 0:2], 0, gpt_idx_a,
                                               grid_idx_a)
            coeff_b = tran.get_gp_coefficients(data[:-1, 2:4], 0, gpt_idx_b,
                                               grid_idx_b)

            return dict(feature_vecs_a=feature_vecs_a,
                        feature_vecs_b=feature_vecs_b,
                        gpt_idx_a=gpt_idx_a,
                        gpt_idx_b=gpt_idx_b,
                        grid_idx_a=grid_idx_a,
                        grid_idx_b=grid_idx_b,
                        coeff_a=coeff_a,
                        coeff_b=coeff_b)

    memory_kwargs = get_memory_kwargs(training_data, train_rs)
    valid_data_memory_kwargs = get_memory_kwargs(valid_data, train_rs)

    # NOTE(review): result is unused; presumably a warm-up/sanity check that
    # the cached memory kwargs are consistent with the model -- confirm.
    log_prob = model.log_likelihood(training_data, **memory_kwargs)

    ##################### training ############################
    if train_model:
        print("start training")
        list_of_losses = []
        if load_opt_dir != "":
            opt = joblib.load(load_opt_dir)
        else:
            opt = None
        # One training stage per (num_iters, lr) pair; the optimizer state
        # carries over between stages.
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters,
                                                list_of_lr)):
            training_losses, opt, valid_losses = model.fit(
                training_data,
                optimizer=opt,
                method='adam',
                num_iters=num_iters,
                lr=lr,
                pbar_update_interval=pbar_update_interval,
                valid_data=valid_data,
                valid_data_memory_kwargs=valid_data_memory_kwargs,
                **memory_kwargs)
            list_of_losses.append(training_losses)

            checkpoint_dir = rslt_dir + "/checkpoint_{}".format(i)

            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))
            # save model and opt
            joblib.dump(model, checkpoint_dir + "/model")
            joblib.dump(opt, checkpoint_dir + "/optimizer")

            # save losses
            losses = dict(training_loss=training_losses,
                          valid_loss=valid_losses)
            joblib.dump(losses, checkpoint_dir + "/losses")

            plt.figure()
            plt.plot(training_losses)
            plt.title("training loss")
            plt.savefig(checkpoint_dir + "/training_losses.jpg")
            plt.close()

            plt.figure()
            plt.plot(valid_losses)
            plt.title("validation loss")
            plt.savefig(checkpoint_dir + "/valid_losses.jpg")
            plt.close()

            # save rest
            if i in ckpts_not_to_save:
                print("ckpt {}: skip!\n".format(i))
                continue
            with torch.no_grad():
                rslt_saving(rslt_dir=checkpoint_dir,
                            model=model,
                            data=training_data,
                            memory_kwargs=memory_kwargs,
                            list_of_k_steps=list_of_k_steps,
                            sample_T=sample_T,
                            quiver_scale=quiver_scale,
                            valid_data=valid_data,
                            valid_data_memory_kwargs=valid_data_memory_kwargs,
                            device=device)

    else:
        # only save the results
        rslt_saving(rslt_dir=rslt_dir,
                    model=model,
                    data=training_data,
                    memory_kwargs=memory_kwargs,
                    list_of_k_steps=list_of_k_steps,
                    sample_T=sample_T,
                    quiver_scale=quiver_scale,
                    valid_data=valid_data,
                    valid_data_memory_kwargs=valid_data_memory_kwargs,
                    device=device)

    print("Finish running!")
def rslt_saving(rslt_dir, model, data, mouse, sample_T, train_model, losses,
                quiver_scale, x_grids, y_grids):
    """Infer states, draw samples, and save summary numbers and figures.

    Writes ``summary.json`` (errors, transition matrix, speeds), a joblib
    ``numbers`` dump, and a set of plots (most-likely states, samples,
    quivers, angle/speed/space distributions) under ``rslt_dir``.

    NOTE(review): the call sites earlier in this file invoke rslt_saving with
    different keyword arguments (memory_kwargs, list_of_k_steps, ...); the
    definitions appear to come from different revisions -- confirm which one
    is current.
    """
    tran = model.observation.transformation

    _, D = data.shape
    assert D == 2 or D == 4, "D must be either 2 or 4."

    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1

    K = model.K

    #################### inference ###########################

    print("\ninferring most likely states...")
    z = model.most_likely_states(data)

    print("0 step prediction")
    # Bound the prediction cost: only the last 5000 frames are predicted.
    if data.shape[0] <= 5000:
        data_to_predict = data
    else:
        data_to_predict = data[-5000:]
    x_predict = k_step_prediction(model, z, data_to_predict)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()),
                            axis=0)

    print("5 step prediction")
    # k-step predictions start at t=k, so compare against data shifted by 5.
    x_predict_5 = k_step_prediction(model, z, data_to_predict, k=5)
    x_predict_5_err = np.mean(np.abs(x_predict_5 -
                                     data_to_predict[5:].numpy()),
                              axis=0)

    ################### samples #########################

    # One free-running sample, and one conditioned on a start near the
    # arena center.
    sample_z, sample_x = model.sample(sample_T)

    center_z = torch.tensor([0], dtype=torch.int)
    if D == 4:
        center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64)
    else:
        center_x = torch.tensor([[150, 190]], dtype=torch.float64)
    sample_z_center, sample_x_center = model.sample(sample_T,
                                                    prefix=(center_z,
                                                            center_x))

    ################## dynamics #####################

    # quiver: evaluate the learned transformation on a 30x30 position grid.
    XX, YY = np.meshgrid(np.linspace(20, 310, 30), np.linspace(0, 380, 30))
    XY = np.column_stack(
        (np.ravel(XX), np.ravel(YY)))  # shape (900,2) grid values
    if D == 2:
        XY_grids = XY
    else:
        # Two animals: place both at the same grid position.
        XY_grids = np.concatenate((XY, XY), axis=1)

    XY_next = tran.transform(torch.tensor(XY_grids, dtype=torch.float64))
    # Displacement field per hidden state (broadcast over the K axis).
    dXY = XY_next.detach().numpy() - XY_grids[:, None]

    #################### saving ##############################

    print("begin saving...")

    # save summary
    avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = np.average(np.abs(np.diff(sample_x, axis=0)), axis=0)
    avg_sample_center_speed = np.average(np.abs(
        np.diff(sample_x_center, axis=0)),
                                         axis=0)
    avg_data_speed = np.average(np.abs(np.diff(data.numpy(), axis=0)), axis=0)

    transition_matrix = model.transition.stationary_transition_matrix
    # Detach only when the matrix is still attached to the autograd graph.
    if transition_matrix.requires_grad:
        transition_matrix = transition_matrix.detach().numpy()
    else:
        transition_matrix = transition_matrix.numpy()

    cluster_centers = get_np(tran.mus_loc)

    summary_dict = {
        "init_dist": model.init_dist.detach().numpy(),
        "transition_matrix": transition_matrix,
        "x_predict_err": x_predict_err,
        "x_predict_5_err": x_predict_5_err,
        "mus": cluster_centers,
        "variance": torch.exp(model.observation.log_sigmas).detach().numpy(),
        "log_likes": model.log_likelihood(data).detach().numpy(),
        "avg_transform_speed": avg_transform_speed,
        "avg_data_speed": avg_data_speed,
        "avg_sample_speed": avg_sample_speed,
        "avg_sample_center_speed": avg_sample_center_speed
    }
    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = {
        "z": z,
        "x_predict": x_predict,
        "x_predict_5": x_predict_5,
        "sample_z": sample_z,
        "sample_x": sample_x,
        "sample_z_center": sample_z_center,
        "sample_x_center": sample_x_center
    }

    if train_model:
        saving_dict['losses'] = losses
        plt.figure()
        plt.plot(losses)
        plt.savefig(rslt_dir + "/losses.jpg")
        plt.close()
    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    plot_z(z, K, title="most likely z for the ground truth")
    plt.savefig(rslt_dir + "/z.jpg")
    plt.close()

    if not os.path.exists(rslt_dir + "/samples"):
        os.makedirs(rslt_dir + "/samples")
        print("Making samples directory...")

    plot_z(sample_z, K, title="sample")
    plt.savefig(rslt_dir + "/samples/sample_z_{}.jpg".format(sample_T))
    plt.close()

    plot_z(sample_z_center, K, title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/sample_z_center_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(data,
               title="ground truth_{}".format(mouse),
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/ground_truth.jpg")
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x,
               title="sample_{}".format(mouse),
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x_center,
               title="sample (starting from center)_{}".format(mouse),
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_center_{}.jpg".format(sample_T))
    plt.close()

    plot_realdata_quiver(data,
                         z,
                         K,
                         x_grids,
                         y_grids,
                         title="ground truth",
                         cluster_centers=cluster_centers)
    plt.savefig(rslt_dir + "/samples/quiver_ground_truth.jpg", dpi=200)
    # Fix: this figure was never closed, leaking an open matplotlib figure.
    plt.close()

    plot_realdata_quiver(sample_x,
                         sample_z,
                         K,
                         x_grids,
                         y_grids,
                         title="sample",
                         cluster_centers=cluster_centers)
    plt.savefig(rslt_dir + "/samples/quiver_sample_x_{}.jpg".format(sample_T),
                dpi=200)
    plt.close()

    plot_realdata_quiver(sample_x_center,
                         sample_z_center,
                         K,
                         x_grids,
                         y_grids,
                         title="sample (starting from center)",
                         cluster_centers=cluster_centers)
    plt.savefig(rslt_dir +
                "/samples/quiver_sample_x_center_{}.jpg".format(sample_T),
                dpi=200)
    plt.close()

    # plot mus
    plot_cluster_centers(cluster_centers, x_grids, y_grids)
    plt.savefig(rslt_dir + "/samples/cluster_centers.jpg", dpi=200)
    # Fix: this figure was never closed, leaking an open matplotlib figure.
    plt.close()

    if not os.path.exists(rslt_dir + "/dynamics"):
        os.makedirs(rslt_dir + "/dynamics")
        print("Making dynamics directory...")

    if D == 2:
        plot_quiver(XY_grids,
                    dXY,
                    mouse,
                    K=K,
                    scale=quiver_scale,
                    alpha=0.9,
                    title="quiver ({})".format(mouse),
                    x_grids=x_grids,
                    y_grids=y_grids,
                    grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_{}.jpg".format(mouse),
                    dpi=200)
        plt.close()
    else:
        # Two animals: plot each animal's displacement field separately.
        plot_quiver(XY_grids[:, 0:2],
                    dXY[..., 0:2],
                    'virgin',
                    K=K,
                    scale=quiver_scale,
                    alpha=0.9,
                    title="quiver (virgin)",
                    x_grids=x_grids,
                    y_grids=y_grids,
                    grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_a.jpg", dpi=200)
        plt.close()

        plot_quiver(XY_grids[:, 2:4],
                    dXY[..., 2:4],
                    'mother',
                    K=K,
                    scale=quiver_scale,
                    alpha=0.9,
                    title="quiver (mother)",
                    x_grids=x_grids,
                    y_grids=y_grids,
                    grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_b.jpg", dpi=200)
        plt.close()

    if not os.path.exists(rslt_dir + "/distributions"):
        os.makedirs(rslt_dir + "/distributions")
        print("Making distributions directory...")

    if D == 4:
        data_angles_a, data_angles_b = get_all_angles(data, x_grids, y_grids)
        sample_angles_a, sample_angles_b = get_all_angles(
            sample_x, x_grids, y_grids)
        sample_x_center_angles_a, sample_x_center_angles_b = get_all_angles(
            sample_x_center, x_grids, y_grids)

        plot_list_of_angles(
            [data_angles_a, sample_angles_a, sample_x_center_angles_a],
            ['data', 'sample', 'sample_c'], "direction distribution (virgin)",
            n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_a.jpg")
        plt.close()
        plot_list_of_angles(
            [data_angles_b, sample_angles_b, sample_x_center_angles_b],
            ['data', 'sample', 'sample_c'], "direction distribution (mother)",
            n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_b.jpg")
        plt.close()

        data_speed_a, data_speed_b = get_speed(data, x_grids, y_grids)
        sample_speed_a, sample_speed_b = get_speed(sample_x, x_grids, y_grids)
        sample_x_center_speed_a, sample_x_center_speed_b = get_speed(
            sample_x_center, x_grids, y_grids)

        plot_list_of_speed(
            [data_speed_a, sample_speed_a, sample_x_center_speed_a],
            ['data', 'sample', 'sample_c'], "speed distribution (virgin)", n_x,
            n_y)
        plt.savefig(rslt_dir + "/distributions/speed_a.jpg")
        plt.close()
        plot_list_of_speed(
            [data_speed_b, sample_speed_b, sample_x_center_speed_b],
            ['data', 'sample', 'sample_c'], "speed distribution (mother)", n_x,
            n_y)
        plt.savefig(rslt_dir + "/distributions/speed_b.jpg")
        plt.close()
    else:
        data_angles_a = get_all_angles(data, x_grids, y_grids)
        sample_angles_a = get_all_angles(sample_x, x_grids, y_grids)
        sample_x_center_angles_a = get_all_angles(sample_x_center, x_grids,
                                                  y_grids)

        plot_list_of_angles(
            [data_angles_a, sample_angles_a, sample_x_center_angles_a],
            ['data', 'sample', 'sample_c'], "direction distribution (virgin)",
            n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_{}.jpg".format(mouse))
        plt.close()

        data_speed_a = get_speed(data, x_grids, y_grids)
        sample_speed_a = get_speed(sample_x, x_grids, y_grids)
        sample_x_center_speed_a = get_speed(sample_x_center, x_grids, y_grids)

        plot_list_of_speed(
            [data_speed_a, sample_speed_a, sample_x_center_speed_a],
            ['data', 'sample', 'sample_c'], "speed distribution (virgin)", n_x,
            n_y)
        plt.savefig(rslt_dir + "/distributions/speed_{}.jpg".format(mouse))
        plt.close()

    # Spatial-occupancy plots are best-effort: cap at 36000 frames, skip
    # sequences that are too short, and never let a plotting failure abort
    # the whole saving routine.
    try:
        if 100 < data.shape[0] <= 36000:
            plot_space_dist(data, x_grids, y_grids)
        elif data.shape[0] > 36000:
            plot_space_dist(data[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_data.jpg")
        plt.close()

        if 100 < sample_x.shape[0] <= 36000:
            plot_space_dist(sample_x, x_grids, y_grids)
        elif sample_x.shape[0] > 36000:
            plot_space_dist(sample_x[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x.jpg")
        plt.close()

        if 100 < sample_x_center.shape[0] <= 36000:
            plot_space_dist(sample_x_center, x_grids, y_grids)
        elif sample_x_center.shape[0] > 36000:
            plot_space_dist(sample_x_center[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x_center.jpg")
        plt.close()
    # Fix: bare ``except:`` also swallowed KeyboardInterrupt/SystemExit;
    # narrow to Exception so those still propagate.
    except Exception:
        print("plot_space_dist unsuccessful")
Exemple #30
0
    def fit(self,
            datas,
            inputs=None,
            optimizer=None,
            method='adam',
            num_iters=1000,
            lr=0.001,
            pbar_update_interval=10,
            valid_data=None,
            transition_memory_kwargs=None,
            valid_data_transition_memory_kwargs=None,
            valid_data_memory_kwargs=None,
            **memory_kwargs):
        """Train the model by gradient descent on the negative log-likelihood.

        Builds an Adam/SGD optimizer (or re-uses ``optimizer`` with its
        learning rate reset to ``lr``), runs ``num_iters`` update steps, and
        optionally tracks a validation loss under ``torch.no_grad()``.

        Returns ``(losses, optimizer)`` or, when ``valid_data`` is given,
        ``(losses, optimizer, valid_losses)``.
        """
        pbar = trange(num_iters, file=sys.stdout)

        if optimizer is None:
            # Dispatch on the requested method name.
            optimizer_classes = {
                'adam': torch.optim.Adam,
                'sgd': torch.optim.SGD,
            }
            if method not in optimizer_classes:
                raise ValueError("Method must be chosen from adam and sgd.")
            optimizer = optimizer_classes[method](self.trainable_params,
                                                  lr=lr)
        else:
            assert isinstance(optimizer, (torch.optim.SGD, torch.optim.Adam)), \
                "Optimizer must be chosen from SGD or Adam"
            # Re-used optimizer: override its learning rate with the new lr.
            for group in optimizer.param_groups:
                group['lr'] = lr

        losses = []
        track_validation = valid_data is not None
        if track_validation:
            valid_losses = []
            valid_data_memory_kwargs = valid_data_memory_kwargs or {}

        for step in range(num_iters):
            optimizer.zero_grad()
            loss = self.loss(datas,
                             inputs,
                             transition_memory_kwargs=transition_memory_kwargs,
                             **memory_kwargs)
            loss.backward()
            optimizer.step()

            loss = get_np(loss)
            losses.append(loss)

            # Validation loss is evaluated without building a graph.
            if track_validation and len(valid_data) > 0:
                with torch.no_grad():
                    valid_loss = self.loss(
                        valid_data,
                        transition_memory_kwargs=valid_data_transition_memory_kwargs,
                        **valid_data_memory_kwargs)
                    valid_losses.append(get_np(valid_loss))

            # Advance the progress bar once per update interval.
            if step % pbar_update_interval == 0:
                with nostdout():
                    pbar.set_description('iter {} loss {:.2f}'.format(
                        step, loss))
                    pbar.update(pbar_update_interval)
        pbar.close()

        if track_validation:
            return losses, optimizer, valid_losses
        return losses, optimizer