コード例 #1
0
def k_step_prediction_for_lineargrid_model(model, model_z, data, **kwargs):
    """One-step-ahead prediction under a linear-grid model.

    For each t, samples x_t from the observation model conditioned on the
    inferred discrete state model_z[t] and the previous observation,
    reusing precomputed per-timestep memory supplied via **kwargs.  Falls
    back to the generic `k_step_prediction` when the memory is missing.

    Args:
        model: HMM-style model exposing `observation.sample_x`.
        model_z: sequence of discrete states, one per row of `data`.
        data: (T, D) array-like with D == 2 or D == 4 -- presumably x/y
            coordinates for one or two animals; TODO confirm.
        **kwargs: precomputed memory.  For D == 4: "feature_vecs",
            "gridpoints", "gridpoints_idx" (each a pair of per-animal
            sequences).  For D == 2: "coeffs" and "gridpoints_idx".

    Returns:
        (T, ...) numpy array of sampled predictions, or None when `data`
        is empty.
    """
    if len(data) == 0:
        return None
    data = check_and_convert_to_tensor(data)
    _, D = data.shape
    assert D == 2 or D == 4
    if D == 4:
        feature_vecs = kwargs.get("feature_vecs", None)
        gridpoints = kwargs.get("gridpoints", None)
        gridpoints_idx = kwargs.get("gridpoints_idx", None)
        if feature_vecs is None or gridpoints_idx is None or gridpoints is None:
            print("Did not provide memory information")
            return k_step_prediction(model, model_z, data)
        else:
            grid_points_idx_a, grid_points_idx_b = gridpoints_idx
            gridpoints_a, gridpoints_b = gridpoints
            feature_vecs_a, feature_vecs_b = feature_vecs

            x_predict_arr = []
            # t = 0: no history yet -- sample from the empty prefix.
            x_predict = model.observation.sample_x(model_z[0],
                                                   data[:0],
                                                   return_np=True,
                                                   with_noise=True)
            x_predict_arr.append(x_predict)
            for t in range(1, data.shape[0]):
                x_predict = model.observation.sample_x(
                    model_z[t],
                    data[t - 1:t],
                    return_np=True,
                    with_noise=True,
                    gridpoints=(gridpoints_a[t - 1], gridpoints_b[t - 1]),
                    gridpoints_idx=(grid_points_idx_a[t - 1],
                                    grid_points_idx_b[t - 1]),
                    feature_vec=(feature_vecs_a[t - 1:t],
                                 feature_vecs_b[t - 1:t]))
                x_predict_arr.append(x_predict)

            x_predict_arr = np.array(x_predict_arr)
    else:
        coeffs = kwargs.get("coeffs", None)
        gridpoints_idx = kwargs.get("gridpoints_idx", None)
        if coeffs is None or gridpoints_idx is None:
            print("Did not provide memory ")
            return k_step_prediction(model, model_z, data)
        else:
            x_predict_arr = []
            # t = 0: no history yet -- sample from the empty prefix.
            x_predict = model.observation.sample_x(model_z[0],
                                                   data[:0],
                                                   return_np=True,
                                                   with_noise=True)
            x_predict_arr.append(x_predict)
            for t in range(1, data.shape[0]):
                x_predict = model.observation.sample_x(
                    model_z[t],
                    data[t - 1:t],
                    return_np=True,
                    with_noise=True,
                    coeffs=coeffs[t - 1:t],
                    gridpoints_idx=gridpoints_idx[t - 1])
                # BUG FIX: predictions were computed but never collected,
                # so this branch used to return a 1-element list holding
                # only the t = 0 sample.
                x_predict_arr.append(x_predict)

            # BUG FIX: convert to ndarray, matching the D == 4 branch.
            x_predict_arr = np.array(x_predict_arr)
    return x_predict_arr
コード例 #2
0
def k_step_prediction_for_lstm_model(model, model_z, data, feature_vecs=None):
    """0-step prediction for an LSTM-transformation model.

    Samples x_t conditioned on model_z[t] and the observed prefix
    data[:t], passing the precomputed per-animal feature vectors to the
    observation.  Without `feature_vecs` this falls back to the generic
    `k_step_prediction`.
    """
    data = check_and_convert_to_tensor(data)

    if feature_vecs is None:
        print("Did not provide memory information")
        return k_step_prediction(model, model_z, data)

    fvs_a, fvs_b = feature_vecs

    # t = 0 has no history: sample from the empty prefix (no noise flag,
    # matching the original behavior).
    predictions = [
        model.observation.sample_x(model_z[0], data[:0], return_np=True)
    ]
    for step in range(1, data.shape[0]):
        predictions.append(
            model.observation.sample_x(
                model_z[step],
                data[:step],
                return_np=True,
                with_noise=True,
                feature_vec=(fvs_a[step - 1:step], fvs_b[step - 1:step])))

    return np.array(predictions)
コード例 #3
0
def k_step_prediction_for_momentum_feature_model(model,
                                                 model_z,
                                                 data,
                                                 momentum_vecs=None,
                                                 features=None):
    """0-step prediction for a momentum + feature model.

    Falls back to the generic `k_step_prediction` when either precomputed
    memory argument is missing.
    """
    data = check_and_convert_to_tensor(data)

    if momentum_vecs is None or features is None:
        return k_step_prediction(model, model_z, data)

    # t = 0 has no history: sample from the empty prefix.
    preds = [model.observation.sample_x(model_z[0], data[:0], return_np=True)]
    for step in range(1, data.shape[0]):
        sampled = model.observation.sample_x(
            model_z[step],
            data[:step],
            return_np=True,
            with_noise=True,
            momentum_vec=momentum_vecs[step - 1],
            features=(features[0][step - 1], features[1][step - 1]))
        preds.append(sampled)

    return np.array(preds)
コード例 #4
0
def k_step_prediction_for_gpgrid_model(model, model_z, data, **memory_kwargs):
    """0-step prediction for a GP-grid model.

    Samples x_t conditioned on model_z[t] and the observed prefix
    data[:t], forwarding precomputed per-animal GP memory to
    `observation.sample_x`.  Falls back to the generic
    `k_step_prediction` when no memory is supplied.

    Expected keys in **memory_kwargs (all per-timestep sequences):
    "feature_vecs_a"/"_b", "gpt_idx_a"/"_b", "grid_idx_a"/"_b", and
    either "coeff_a"/"_b" or "dist_sq_a"/"_b" -- presumably the two
    alternate ways the observation weights grid points; TODO confirm.
    """
    data = check_and_convert_to_tensor(data)

    if memory_kwargs == {}:
        print("Did not provide memory information")
        return k_step_prediction(model, model_z, data)
    else:

        feature_vecs_a = memory_kwargs.get("feature_vecs_a", None)
        feature_vecs_b = memory_kwargs.get("feature_vecs_b", None)
        gpt_idx_a = memory_kwargs.get("gpt_idx_a", None)
        gpt_idx_b = memory_kwargs.get("gpt_idx_b", None)
        grid_idx_a = memory_kwargs.get("grid_idx_a", None)
        grid_idx_b = memory_kwargs.get("grid_idx_b")
        coeff_a = memory_kwargs.get("coeff_a", None)
        coeff_b = memory_kwargs.get("coeff_b", None)
        dist_sq_a = memory_kwargs.get("dist_sq_a", None)
        dist_sq_b = memory_kwargs.get("dist_sq_b", None)

        x_predict_arr = []
        # t = 0 has no history: sample from the empty prefix.
        x_predict = model.observation.sample_x(model_z[0],
                                               data[:0],
                                               return_np=True)
        x_predict_arr.append(x_predict)
        for t in range(1, data.shape[0]):
            # The presence of dist_sq selects which memory variant the
            # observation receives: coefficients vs squared distances.
            if dist_sq_a is None:
                x_predict = model.observation.sample_x(
                    model_z[t],
                    data[:t],
                    return_np=True,
                    with_noise=True,
                    feature_vec_a=feature_vecs_a[t - 1:t],
                    feature_vec_b=feature_vecs_b[t - 1:t],
                    gpt_idx_a=gpt_idx_a[t - 1:t],
                    gpt_idx_b=gpt_idx_b[t - 1:t],
                    grid_idx_a=grid_idx_a[t - 1:t],
                    grid_idx_b=grid_idx_b[t - 1:t],
                    coeff_a=coeff_a[t - 1:t],
                    coeff_b=coeff_b[t - 1:t])
            else:
                x_predict = model.observation.sample_x(
                    model_z[t],
                    data[:t],
                    return_np=True,
                    with_noise=True,
                    feature_vec_a=feature_vecs_a[t - 1:t],
                    feature_vec_b=feature_vecs_b[t - 1:t],
                    gpt_idx_a=gpt_idx_a[t - 1:t],
                    gpt_idx_b=gpt_idx_b[t - 1:t],
                    grid_idx_a=grid_idx_a[t - 1:t],
                    grid_idx_b=grid_idx_b[t - 1:t],
                    dist_sq_a=dist_sq_a[t - 1:t],
                    dist_sq_b=dist_sq_b[t - 1:t])
            x_predict_arr.append(x_predict)

        x_predict_arr = np.array(x_predict_arr)
        return x_predict_arr
コード例 #5
0
def k_step_prediction_for_grid_model(model, model_z, data, **memory_kwargs):
    """0-step prediction for a grid model with per-animal memory dicts.

    Expects "memory_kwargs_a" and "memory_kwargs_b" in **memory_kwargs,
    each containing "feature_vecs" and optionally "momentum_vecs".
    Falls back to the generic `k_step_prediction` when either is missing.
    """
    if len(data) == 0:
        return None
    data = check_and_convert_to_tensor(data)

    mkw_a = memory_kwargs.get("memory_kwargs_a", None)
    mkw_b = memory_kwargs.get("memory_kwargs_b", None)
    if mkw_a is None or mkw_b is None:
        print("Did not provide memory information")
        return k_step_prediction(model, model_z, data)

    momentum_a = mkw_a.get("momentum_vecs", None)
    feature_a = mkw_a.get("feature_vecs", None)
    momentum_b = mkw_b.get("momentum_vecs", None)
    feature_b = mkw_b.get("feature_vecs", None)

    # t = 0 has no history: sample from the empty prefix.
    preds = [model.observation.sample_x(model_z[0], data[:0], return_np=True)]
    for step in range(1, data.shape[0]):
        kwargs_a = {"feature_vec": feature_a[step - 1]}
        kwargs_b = {"feature_vec": feature_b[step - 1]}
        if momentum_a is not None:
            kwargs_a["momentum_vec"] = momentum_a[step - 1]
            kwargs_b["momentum_vec"] = momentum_b[step - 1]

        preds.append(
            model.observation.sample_x(model_z[step],
                                       data[:step],
                                       return_np=True,
                                       with_noise=True,
                                       memory_kwargs_a=kwargs_a,
                                       memory_kwargs_b=kwargs_b))

    return np.array(preds)
コード例 #6
0
import matplotlib.pyplot as plt

# Fix RNG state so the sampled data and the fit are reproducible.
torch.manual_seed(0)
np.random.seed(0)

# test fitting

# Model dimensions: K discrete states, D-dimensional observations,
# autoregressive lag order `lags`.
K = 3
D = 2
lags= 5

# Ground-truth model: AR-Gaussian observations with a linear transformation.
trans1 = LinearTransformation(K=K, D=D, lags=lags)
obs1 = ARGaussianObservation(K=K, D=D, transformation=trans1)
model1 = HMM(K=K,D=D, observation=obs1)

# Draw a T-step trajectory (states and observations) from the true model.
T = 100
sample_z, sample_x = model1.sample(T)

# Fresh model fitted from scratch on the sampled observations.
model2 = HMM(K=K, D=D, observation='gaussian', observation_kwargs=dict(lags=lags))

lls, opt = model2.fit(sample_x, num_iters=2000, lr=0.001)

# Viterbi-style decoding of the most likely state sequence.
z_infer = model2.most_likely_states(sample_x)

# 0-step (reconstruction-style) prediction from the fitted model.
x_predict = k_step_prediction(model2, z_infer, sample_x)

# Compare the first observation dimension: prediction vs ground truth.
plt.figure()
plt.plot(x_predict[:,0], label='prediction')
plt.plot(sample_x[:,0], label='truth')
plt.show()
コード例 #7
0
def k_step_prediction_for_gpmodel(model, model_z, data, **memory_kwargs):
    """0-step prediction for a GP observation model.

    Precomputes the GP weight matrices ("A") once over the whole
    trajectory via `observation.get_gp_cache`, then samples x_t
    conditioned on model_z[t] and the observed prefix data[:t].  Falls
    back to the generic `k_step_prediction` when no memory is supplied.
    """
    data = check_and_convert_to_tensor(data)

    T, D = data.shape
    # D == 4: two animals (columns 0:2 and 2:4); D == 2: single animal.
    assert D == 4 or D == 2, D

    K = model.observation.K

    if memory_kwargs == {}:
        print("Did not provide memory information")
        return k_step_prediction(model, model_z, data)
    else:

        # compute As
        if D == 4:
            # Per-animal GP caches over all T-1 transitions at once.
            _, A_a = model.observation.get_gp_cache(data[:-1, 0:2],
                                                    0,
                                                    A_only=True,
                                                    **memory_kwargs)
            _, A_b = model.observation.get_gp_cache(data[:-1, 2:4],
                                                    1,
                                                    A_only=True,
                                                    **memory_kwargs)
            assert A_a.shape == A_b.shape == (
                T - 1, K, 2, model.observation.n_gps * 2), "{}, {}".format(
                    A_a.shape, A_b.shape)

            x_predict_arr = []
            # t = 0 has no history: sample from the empty prefix.
            x_predict = model.observation.sample_x(model_z[0],
                                                   data[:0],
                                                   return_np=True)
            x_predict_arr.append(x_predict)
            for t in range(1, data.shape[0]):
                # Select the cache slice for transition t-1 and the
                # decoded state model_z[t].
                x_predict = model.observation.sample_x(model_z[t],
                                                       data[:t],
                                                       return_np=True,
                                                       with_noise=True,
                                                       A_a=A_a[t - 1:t,
                                                               model_z[t]],
                                                       A_b=A_b[t - 1:t,
                                                               model_z[t]])
                x_predict_arr.append(x_predict)
        else:
            # Single-animal variant: one cache, indexed (state, time)
            # rather than (time, state) as in the D == 4 branch --
            # presumably get_gp_cache returns a different layout here;
            # TODO confirm.
            _, A = model.observation.get_gp_cache(data[:-1],
                                                  A_only=True,
                                                  **memory_kwargs)

            x_predict_arr = []
            x_predict = model.observation.sample_x(model_z[0],
                                                   data[:0],
                                                   return_np=True)
            x_predict_arr.append(x_predict)
            for t in range(1, data.shape[0]):
                x_predict = model.observation.sample_x(model_z[t],
                                                       data[:t],
                                                       return_np=True,
                                                       with_noise=True,
                                                       A=(A[0][model_z[t],
                                                               t - 1:t],
                                                          A[1][model_z[t],
                                                               t - 1:t]))
                x_predict_arr.append(x_predict)

        x_predict_arr = np.array(x_predict_arr)
        return x_predict_arr
コード例 #8
0
def rslt_saving(rslt_dir, model, data, animal, memory_kwargs, list_of_k_steps, sample_T,
                quiver_scale, x_grids=None, y_grids=None, dynamics_T=None,
                valid_data=None, valid_data_memory_kwargs=None, device=torch.device('cpu')):
    """Infer states, evaluate k-step predictions, draw samples and save
    everything (summary JSON, raw numbers, figures) under `rslt_dir`.

    Args:
        rslt_dir: output directory (must already exist).
        model: fitted HMM-style model.
        data: training trajectory, shape (T, D).
        animal: "both" for two-animal data, otherwise a single-animal tag.
        memory_kwargs: precomputed per-timestep memory for `data`
            (falsy -> treated as {}).
        list_of_k_steps: k values for k-step-ahead prediction.
        sample_T: length of the sampled trajectories.
        quiver_scale: scale passed to the quiver plots.
        x_grids, y_grids: grid boundaries; taken from the transformation
            when not supplied.
        dynamics_T: sample length for fixed-z dynamics; required for
            LSTMBasedTransformation.
        valid_data: held-out trajectory.  NOTE(review): the default None
            would crash in `most_likely_states`/`len(valid_data)` -- the
            code assumes a (possibly empty) array is passed; TODO confirm.
        valid_data_memory_kwargs: memory for `valid_data`.
        device: torch device for tensors created here.
    """

    valid_data_memory_kwargs = valid_data_memory_kwargs if valid_data_memory_kwargs else {}

    tran = model.observation.transformation
    if x_grids is None or y_grids is None:
        x_grids = tran.x_grids
        y_grids = tran.y_grids
    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1

    K = model.K

    memory_kwargs = memory_kwargs if memory_kwargs else {}


    #################### inference ###########################

    print("\ninferring most likely states...")
    z = model.most_likely_states(data, **memory_kwargs)
    z_valid = model.most_likely_states(valid_data, **valid_data_memory_kwargs)

    # TODO: address valid_data = None
    print("0 step prediction")
    # Cap the 0-step evaluation at the last 10000 frames to bound cost.
    # TODO: add valid data for other model
    if data.shape[0] <= 10000:
        data_to_predict = data
    else:
        data_to_predict = data[-10000:]
    # Dispatch the 0-step prediction on the transformation type, since
    # each model family consumes differently shaped memory kwargs.
    if isinstance(tran, (LinearGridTransformation, SingleLinearGridTransformation)):
        x_predict = k_step_prediction_for_lineargrid_model(model, z, data_to_predict, **memory_kwargs)
        x_predict_valid = k_step_prediction_for_lineargrid_model(model, z_valid, valid_data, **valid_data_memory_kwargs)
    elif isinstance(tran, GPGridTransformation):
        x_predict = k_step_prediction_for_gpgrid_model(model, z, data_to_predict, **memory_kwargs)
        x_predict_valid = k_step_prediction_for_gpgrid_model(model, z_valid, valid_data, **valid_data_memory_kwargs)
    elif isinstance(tran, WeightedGridTransformation):
        x_predict = k_step_prediction_for_weightedgrid_model(model, z, data_to_predict, **memory_kwargs)
        x_predict_valid = k_step_prediction_for_weightedgrid_model(model, z_valid, valid_data, **valid_data_memory_kwargs)
    elif isinstance(tran, (LSTMTransformation, UniLSTMTransformation)):
        x_predict = k_step_prediction_for_lstm_model(model, z, data_to_predict,
                                                     feature_vecs=memory_kwargs["feature_vecs"])
        x_predict_valid = k_step_prediction_for_lstm_model(model, z_valid, valid_data,
                                                           feature_vecs=valid_data_memory_kwargs["feature_vecs"])
    elif isinstance(tran, LSTMBasedTransformation):
        x_predict = k_step_prediction_for_lstm_based_model(model, z, data_to_predict, k=0,
                                                           feature_vecs=memory_kwargs["feature_vecs"])
        x_predict_valid = k_step_prediction_for_lstm_based_model(model, z_valid, valid_data, k=0,
                                                           feature_vecs=valid_data_memory_kwargs["feature_vecs"])
    else:
        raise ValueError("Unsupported transformation!")
    x_predict_err = np.mean(np.abs(x_predict - get_np(data_to_predict)), axis=0)
    if len(valid_data) == 0:
        x_predict_valid_err = None
    else:
        x_predict_valid_err = np.mean(np.abs(x_predict_valid - get_np(valid_data)), axis=0)

    dict_of_x_predict_k = dict(x_predict_0=x_predict, x_predict_v_0=x_predict_valid)
    dict_of_x_predict_k_err = dict(x_predict_0_err=x_predict_err, x_predict_v_0_err=x_predict_valid_err)

    for k_step in list_of_k_steps:
        print("{} step prediction".format(k_step))
        if isinstance(tran, LSTMBasedTransformation):
            # TODO: take care of empty valid data
            x_predict_k = k_step_prediction_for_lstm_based_model(model, z, data_to_predict, k=k_step)
            # BUG FIX: previously called k_step_prediction_for_lstm_model
            # (which has no `k` parameter) on the *training* z/data, so the
            # "valid" k-step error was computed on training data.
            x_predict_valid_k = k_step_prediction_for_lstm_based_model(model, z_valid, valid_data, k=k_step)
        else:
            x_predict_k = k_step_prediction(model, z, data_to_predict, k=k_step)
            x_predict_valid_k = k_step_prediction(model, z_valid, valid_data, k=k_step)
        x_predict_k_err = np.mean(np.abs(x_predict_k - get_np(data_to_predict[k_step:])), axis=0)
        if len(valid_data) == 0:
            x_predict_valid_k_err = None
        else:
            x_predict_valid_k_err = np.mean(np.abs(x_predict_valid_k - get_np(valid_data[k_step:])), axis=0)
        dict_of_x_predict_k["x_predict_{}".format(k_step)] = x_predict_k
        dict_of_x_predict_k["x_predict_v_{}".format(k_step)] = x_predict_valid_k
        dict_of_x_predict_k_err["x_predict_{}_err".format(k_step)] = x_predict_k_err
        dict_of_x_predict_k_err["x_predict_v_{}_err".format(k_step)] = x_predict_valid_k_err


    ################### samples #########################
    print("sampling")
    center_z = torch.tensor([0], dtype=torch.int, device=device)
    if animal == "both":
        center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64, device=device)
    else:
        center_x = torch.tensor([[150, 190]], dtype=torch.float64, device=device)

    if isinstance(tran, LSTMBasedTransformation):
        # Fresh lstm_states dicts so each sampling run starts from a
        # clean recurrent state.
        lstm_states = {}
        sample_z, sample_x = model.sample(sample_T, lstm_states=lstm_states)

        lstm_states = {}
        sample_z_center, sample_x_center = model.sample(sample_T, prefix=(center_z, center_x), lstm_states=lstm_states)
    else:
        sample_z, sample_x = model.sample(sample_T)

        sample_z_center, sample_x_center = model.sample(sample_T, prefix=(center_z, center_x))


    ################## dynamics #####################

    if isinstance(tran, (LinearGridTransformation, SingleLinearGridTransformation,
                         GPGridTransformation, WeightedGridTransformation)):
        # quiver: evaluate the learned transformation on a 30x30 lattice.
        if animal == 'both':
            XX, YY = np.meshgrid(np.linspace(20, 310, 30),
                                 np.linspace(0, 380, 30))
            XY = np.column_stack((np.ravel(XX), np.ravel(YY)))  # shape (900,2) grid values
            XY_grids = np.concatenate((XY, XY), axis=1)  # (900, 4)
        else:
            XX, YY = np.meshgrid(np.linspace(20, 310, 30),
                                 np.linspace(0, 380, 30))
            XY_grids = np.column_stack((np.ravel(XX), np.ravel(YY)))  # shape (900,2) grid values

        XY_next = tran.transform(torch.tensor(XY_grids, dtype=torch.float64, device=device))
        dXY = get_np(XY_next) - XY_grids[:, None]

    # TODO: maybe use sample condition on z (transformation) to show the dynamics
    samples_on_fixed_zs = []
    if isinstance(tran, LSTMBasedTransformation):
        assert dynamics_T is not None
        for k in range(K):
            lstm_states = {}
            fixed_z = torch.ones(dynamics_T, dtype=torch.int) * k
            samples_on_fixed_z = model.sample_condition_on_zs(zs=fixed_z, transformation=True, return_np=True,
                                                              lstm_states=lstm_states)
            samples_on_fixed_zs.append(samples_on_fixed_z)

    #################### saving ##############################

    print("begin saving...")

    # save summary
    if isinstance(tran, (LinearGridTransformation, SingleLinearGridTransformation,
                         GPGridTransformation, WeightedGridTransformation)):
        avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = np.average(np.abs(np.diff(sample_x, axis=0)), axis=0)
    avg_sample_center_speed = np.average(np.abs(np.diff(sample_x_center, axis=0)), axis=0)
    avg_data_speed = np.average(np.abs(np.diff(get_np(data), axis=0)), axis=0)

    if isinstance(model.transition, StationaryTransition):
        transition_matrix = model.transition.stationary_transition_matrix
    elif isinstance(model.transition, GridTransition):
        transition_matrix = model.transition.grid_transition_matrix
    else:
        raise ValueError("unsupported transition matrix type: {}".format(type(model.transition)))

    transition_matrix = get_np(transition_matrix)

    summary_dict = {"init_dist": get_np(model.init_dist),
                    "transition_matrix": transition_matrix,
                    "variance": get_np(torch.exp(model.observation.log_sigmas)),
                    "log_likes": get_np(model.log_likelihood(data, **memory_kwargs)),
                    "avg_data_speed": avg_data_speed,
                    "avg_sample_speed": avg_sample_speed, "avg_sample_center_speed": avg_sample_center_speed}
    summary_dict = {**dict_of_x_predict_k_err, **summary_dict}
    if len(valid_data) > 0:
        summary_dict["valid_log_likes"] = get_np(model.log_likelihood(valid_data, **valid_data_memory_kwargs))
    if isinstance(tran, GPGridTransformation):
        summary_dict["real_rs"] = get_np(tran.rs_factor * torch.sigmoid(tran.rs))
    if isinstance(tran, WeightedGridTransformation):
        summary_dict["beta"] = get_np(tran.beta)
    if isinstance(tran, (LinearGridTransformation, GPGridTransformation, WeightedGridTransformation)):
        summary_dict["avg_transform_speed"] = avg_transform_speed
    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = {"z": z, "z_valid": z_valid, "sample_z": sample_z, "sample_x": sample_x,
                   "sample_z_center": sample_z_center, "sample_x_center": sample_x_center}
    saving_dict = {**dict_of_x_predict_k, **saving_dict}
    if isinstance(tran, LSTMBasedTransformation):
        saving_dict["samples_on_fixed_zs"] = samples_on_fixed_zs

    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    if model.D == 2 and isinstance(model.transition, GridTransition):
        plot_grid_transition(n_x, n_y, model.transition.grid_transition_matrix)
        plt.savefig(rslt_dir + "/grid_transition.jpg")
        plt.close()


    plot_z(z, K, title="most likely z for the ground truth")
    plt.savefig(rslt_dir + "/z.jpg")
    plt.close()

    if len(valid_data) > 0:
        plot_z(z_valid, K, title="most likely z for valid data")
        plt.savefig(rslt_dir + "/z_valid.jpg")
        plt.close()

    if not os.path.exists(rslt_dir + "/samples"):
        os.makedirs(rslt_dir + "/samples")
        print("Making samples directory...")

    plot_z(sample_z, K, title="sample")
    plt.savefig(rslt_dir + "/samples/sample_z_{}.jpg".format(sample_T))
    plt.close()

    plot_z(sample_z_center, K, title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/sample_z_center_{}.jpg".format(sample_T))
    plt.close()

    # NOTE(review): xlim upper bound uses ARENA_YMAX (not ARENA_XMAX) in
    # all four trajectory plots below -- possibly intentional for a square
    # arena; confirm before changing.
    plt.figure(figsize=(4, 4))
    plot_mouse(data, title="ground truth (training)", xlim=[ARENA_XMIN - 20, ARENA_YMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/ground_truth.jpg")
    plt.close()

    if len(valid_data) > 0:
        plt.figure(figsize=(4, 4))
        plot_mouse(valid_data, title="ground truth (valid)", xlim=[ARENA_XMIN - 20, ARENA_YMAX + 20],
                   ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
        plt.legend()
        plt.savefig(rslt_dir + "/samples/ground_truth_valid.jpg")
        plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x, title="sample", xlim=[ARENA_XMIN - 20, ARENA_YMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x_center, title="sample (starting from center)", xlim=[ARENA_XMIN - 20, ARENA_YMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_center_{}.jpg".format(sample_T))
    plt.close()

    plot_realdata_quiver(data, z, K, x_grids, y_grids, title="ground truth (training)")
    plt.savefig(rslt_dir + "/samples/quiver_ground_truth.jpg", dpi=200)
    plt.close()

    if len(valid_data) > 0:
        plot_realdata_quiver(valid_data, z_valid, K, x_grids, y_grids, title="ground truth (valid)")
        plt.savefig(rslt_dir + "/samples/quiver_ground_truth_valid.jpg", dpi=200)
        plt.close()

    plot_realdata_quiver(sample_x, sample_z, K, x_grids, y_grids, title="sample")
    plt.savefig(rslt_dir + "/samples/quiver_sample_x_{}.jpg".format(sample_T), dpi=200)
    plt.close()

    plot_realdata_quiver(sample_x_center, sample_z_center, K, x_grids, y_grids, title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/quiver_sample_x_center_{}.jpg".format(sample_T), dpi=200)
    plt.close()

    if isinstance(tran, (LinearGridTransformation, SingleLinearGridTransformation,
                         GPGridTransformation, WeightedGridTransformation)):
        if not os.path.exists(rslt_dir + "/dynamics"):
            os.makedirs(rslt_dir + "/dynamics")
            print("Making dynamics directory...")

        if animal == 'both':
            plot_quiver(XY_grids[:, 0:2], dXY[..., 0:2], 'virgin', K=K, scale=quiver_scale, alpha=0.9,
                        title="quiver (virgin)", x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
            plt.savefig(rslt_dir + "/dynamics/quiver_a.jpg", dpi=200)
            plt.close()

            plot_quiver(XY_grids[:, 2:4], dXY[..., 2:4], 'mother', K=K, scale=quiver_scale, alpha=0.9,
                        title="quiver (mother)", x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
            plt.savefig(rslt_dir + "/dynamics/quiver_b.jpg", dpi=200)
            plt.close()
        else:
            plot_quiver(XY_grids, dXY, animal, K=K, scale=quiver_scale, alpha=0.9,
                        title="quiver ({})".format(animal), x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
            plt.savefig(rslt_dir + "/dynamics/quiver_{}.jpg".format(animal), dpi=200)
            plt.close()

    elif isinstance(tran, LSTMBasedTransformation):
        if not os.path.exists(rslt_dir + "/dynamics"):
            os.makedirs(rslt_dir + "/dynamics")
            print("Making dynamics directory...")

        for k in range(K):
            # BUG FIX: dtype=np.int -- the np.int alias was removed in
            # NumPy 1.24 and raised AttributeError; the builtin int is
            # the documented replacement.
            plot_realdata_quiver(samples_on_fixed_zs[k], np.ones(dynamics_T, dtype=int)*k, K, x_grids=x_grids, y_grids=y_grids,
                                 title="sample conditioned on k={}".format(k))
            plt.savefig(rslt_dir + "/dynamics/samples_on_k{}.jpg".format(k), dpi=200)
            plt.close()

    if not os.path.exists(rslt_dir + "/distributions"):
        os.makedirs(rslt_dir + "/distributions")
        print("Making distributions directory...")

    # sanity checks
    plot_data_condition_on_all_zs(data, z, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_groundtruth.jpg", dpi=100)
    plot_data_condition_on_all_zs(sample_x, sample_z, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_sample_x.jpg", dpi=100)
    plot_data_condition_on_all_zs(sample_x_center, sample_z_center, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_sample_x_center.jpg", dpi=100)

    plot_2d_time_plot_condition_on_all_zs(data, z, K, title='ground truth')
    plt.savefig(rslt_dir + "/distributions/4traces_groundtruth.jpg", dpi=100)
    plot_2d_time_plot_condition_on_all_zs(sample_x, sample_z, K, title='sample_x')
    plt.savefig(rslt_dir + "/distributions/4traces_sample_x.jpg", dpi=100)
    plot_2d_time_plot_condition_on_all_zs(sample_x_center, sample_z_center, K, title='sample_x_center')
    plt.savefig(rslt_dir + "/distributions/4traces_sample_x_center.jpg", dpi=100)

    data_angles = get_all_angles(data, x_grids, y_grids, device=device)
    sample_angles = get_all_angles(sample_x, x_grids, y_grids, device=device)
    sample_x_center_angles = get_all_angles(sample_x_center, x_grids, y_grids,
                                            device=device)

    if animal == 'both':
        plot_list_of_angles([data_angles[0], sample_angles[0], sample_x_center_angles],
                            ['data', 'sample', 'sample_c'], "direction distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_a.jpg")
        plt.close()
        plot_list_of_angles([data_angles[1], sample_angles[1], sample_x_center_angles],
                            ['data', 'sample', 'sample_c'], "direction distribution (mother)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_b.jpg")
        plt.close()
    else:
        plot_list_of_angles([data_angles, sample_angles, sample_x_center_angles],
                            ['data', 'sample', 'sample_c'], "direction distribution ({})".format(animal), n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_{}.jpg".format(animal))
        plt.close()

    data_speed = get_speed(data, x_grids, y_grids, device=device)
    sample_speed = get_speed(sample_x, x_grids, y_grids, device=device)
    sample_x_center_speed = get_speed(sample_x_center, x_grids, y_grids, device=device)

    if animal == 'both':
        plot_list_of_speed([data_speed[0], sample_speed[0], sample_x_center_speed[0]],
                           ['data', 'sample', 'sample_c'], "speed distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_a.jpg")
        plt.close()
        plot_list_of_speed([data_speed[1], sample_speed[1], sample_x_center_speed[1]],
                           ['data', 'sample', 'sample_c'], "speed distribution (mother)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_b.jpg")
        plt.close()
    else:
        plot_list_of_speed([data_speed, sample_speed, sample_x_center_speed],
                           ['data', 'sample', 'sample_c'], "speed distribution ({})".format(animal), n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_{}.jpg".format(animal))
        plt.close()

    # Best-effort: spatial-occupancy histograms are skipped on failure
    # (e.g. degenerate trajectories) rather than aborting the whole save.
    try:
        if 100 < data.shape[0] <= 36000:
            plot_space_dist(data, x_grids, y_grids)
        elif data.shape[0] > 36000:
            plot_space_dist(data[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_data.jpg")
        plt.close()

        if 100 < sample_x.shape[0] <= 36000:
            plot_space_dist(sample_x, x_grids, y_grids)
        elif sample_x.shape[0] > 36000:
            plot_space_dist(sample_x[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x.jpg")
        plt.close()

        if 100 < sample_x_center.shape[0] <= 36000:
            plot_space_dist(sample_x_center, x_grids, y_grids)
        elif sample_x_center.shape[0] > 36000:
            plot_space_dist(sample_x_center[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x_center.jpg")
        plt.close()
    # BUG FIX: bare `except:` also swallowed KeyboardInterrupt/SystemExit;
    # narrow to Exception so the process stays interruptible.
    except Exception:
        print("plot_space_dist unsuccessful")
コード例 #9
0
# `out`, `model`, `data`, `momentum_vecs` and `interaction_vecs` are
# defined earlier in the script (outside this excerpt).
print(out)

##################### training ############################

# Short fit: the precomputed momentum/interaction vectors are forwarded
# to the model so they are not recomputed every iteration.
num_iters = 10
losses, opt = model.fit(data,
                        num_iters=num_iters,
                        lr=0.001,
                        momentum_vecs=momentum_vecs,
                        interaction_vecs=interaction_vecs)

##################### sampling ############################
print("start sampling")
sample_z, sample_x = model.sample(30)

#################### inference ###########################
print("inferiring most likely states...")
z = model.most_likely_states(data,
                             momentum_vecs=momentum_vecs,
                             interaction_vecs=interaction_vecs)

# Prediction with the precomputed features...
print("k step prediction")
x_predict = k_step_prediction_for_momentum_interaction_model(
    model,
    z,
    data,
    momentum_vecs=momentum_vecs,
    interaction_vecs=interaction_vecs)
# ...and the generic 10-step prediction without them, for comparison.
print("k step prediction without precomputed features.")
x_predict_2 = k_step_prediction(model, z, data, 10)
コード例 #10
0
def rslt_saving(rslt_dir, model, data, animal, memory_kwargs, list_of_k_steps, sample_T,
                quiver_scale, x_grids=None, y_grids=None,
                valid_data=None,
                transition_memory_kwargs=None,
                valid_data_transition_memory_kwargs=None,
                valid_data_memory_kwargs=None, device=torch.device('cpu')):
    """Run inference, sampling and k-step prediction for a fitted GP model and
    save summaries, raw numbers and diagnostic figures under ``rslt_dir``.

    :param rslt_dir: output directory; subdirectories are created as needed.
    :param model: fitted HMM whose observation is GPObservation ('both') or
        GPObservationSingle (single animal).
    :param data: training trajectory tensor, shape (T, 2) or (T, 4).
    :param animal: 'both' for two-animal data, otherwise a single-animal label.
    :param memory_kwargs: precomputed observation features for ``data``.
    :param list_of_k_steps: k values for k-step-ahead prediction.
    :param sample_T: length of the sampled trajectories.
    :param quiver_scale: scale forwarded to the quiver plots.
    :param x_grids: grid boundaries; taken from the observation when None.
    :param y_grids: grid boundaries; taken from the observation when None.
    :param valid_data: held-out trajectory (may be empty).
    :param transition_memory_kwargs: precomputed transition features for ``data``.
    :param valid_data_transition_memory_kwargs: transition features for ``valid_data``.
    :param valid_data_memory_kwargs: observation features for ``valid_data``.
    :param device: torch device used for the small prefix tensors.
    """

    # Normalize every optional kwargs dict exactly once (the original code
    # re-defaulted valid_data_memory_kwargs twice).
    transition_memory_kwargs = transition_memory_kwargs if transition_memory_kwargs else {}
    valid_data_transition_memory_kwargs = \
        valid_data_transition_memory_kwargs if valid_data_transition_memory_kwargs else {}
    valid_data_memory_kwargs = valid_data_memory_kwargs if valid_data_memory_kwargs else {}
    memory_kwargs = memory_kwargs if memory_kwargs else {}

    obs = model.observation
    if animal == 'both':
        assert isinstance(obs, GPObservation), type(obs)
    else:
        assert isinstance(obs, GPObservationSingle), type(obs)
    if x_grids is None or y_grids is None:
        x_grids = obs.x_grids
        y_grids = obs.y_grids
    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1

    K = model.K

    #################### inference ###########################

    print("\ninferring most likely states...")
    z = model.most_likely_states(data, transition_mkwargs=transition_memory_kwargs, **memory_kwargs)
    # TODO: address valid_data = None (currently assumed to be a tensor, possibly empty)
    z_valid = model.most_likely_states(valid_data, transition_mkwargs=valid_data_transition_memory_kwargs,
                                       **valid_data_memory_kwargs)

    print("0 step prediction")
    data_to_predict = data

    x_predict = k_step_prediction_for_gpmodel(model, z, data_to_predict, **memory_kwargs)
    x_predict_valid = k_step_prediction_for_gpmodel(model, z_valid, valid_data, **valid_data_memory_kwargs)
    x_predict_err = np.mean(np.abs(x_predict - get_np(data_to_predict)), axis=0)
    if len(valid_data) == 0:
        x_predict_valid_err = None
    else:
        x_predict_valid_err = np.mean(np.abs(x_predict_valid - get_np(valid_data)), axis=0)

    dict_of_x_predict_k = dict(x_predict_0=x_predict, x_predict_v_0=x_predict_valid)
    dict_of_x_predict_k_err = dict(x_predict_0_err=x_predict_err, x_predict_v_0_err=x_predict_valid_err)

    for k_step in list_of_k_steps:
        print("{} step prediction".format(k_step))
        x_predict_k = k_step_prediction(model, z, data_to_predict, k=k_step)
        if len(valid_data) == 0:
            x_predict_valid_k = None
            x_predict_valid_k_err = None
        else:
            # BUGFIX: previously predicted on (z, data_to_predict) again, so the
            # "validation" prediction silently duplicated the training one.
            x_predict_valid_k = k_step_prediction(model, z_valid, valid_data, k=k_step)
            x_predict_valid_k_err = np.mean(np.abs(x_predict_valid_k - get_np(valid_data[k_step:])), axis=0)
        x_predict_k_err = np.mean(np.abs(x_predict_k - get_np(data_to_predict[k_step:])), axis=0)
        dict_of_x_predict_k["x_predict_{}".format(k_step)] = x_predict_k
        dict_of_x_predict_k["x_predict_v_{}".format(k_step)] = x_predict_valid_k
        dict_of_x_predict_k_err["x_predict_{}_err".format(k_step)] = x_predict_k_err
        dict_of_x_predict_k_err["x_predict_v_{}_err".format(k_step)] = x_predict_valid_k_err


    ################### samples #########################
    print("sampling")
    center_z = torch.tensor([0], dtype=torch.int, device=device)
    if animal == 'both':
        center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64, device=device)
    else:
        center_x = torch.tensor([[150, 190]], dtype=torch.float64, device=device)

    sample_z, sample_x = model.sample(sample_T)
    # Second sample is conditioned on a fixed "center" starting position.
    sample_z_center, sample_x_center = model.sample(sample_T, prefix=(center_z, center_x))


    ################## dynamics #####################
    print("dynamics")
    # quiver: evaluate the learned mean dynamics on a 30x30 grid of positions
    XX, YY = np.meshgrid(np.linspace(20, 310, 30),
                         np.linspace(0, 380, 30))
    XY_grids = np.column_stack((np.ravel(XX), np.ravel(YY)))  # shape (900,2) grid values
    if animal == 'both':
        XY_grids = np.concatenate((XY_grids, XY_grids), axis=1)  # (900, 4)

    # TODO, fix not based on z
    if animal == 'both':
        XY_next_a, _ = obs.get_mu_and_cov_for_single_animal(XY_grids[:, 0:2], 0, mu_only=True)
        XY_next_b, _ = obs.get_mu_and_cov_for_single_animal(XY_grids[:, 2:4], 1, mu_only=True)
        XY_next = torch.cat((XY_next_a, XY_next_b), dim=-1)
        dXY = get_np(XY_next) - XY_grids[:, None]
    else:
        XY_next, _ = obs.get_mu_and_cov_for_single_animal(XY_grids, mu_only=True)
        dXY = get_np(XY_next) - XY_grids[:, None]

    #################### saving ##############################

    print("begin saving...")

    # save summary
    avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = np.average(np.abs(np.diff(sample_x, axis=0)), axis=0)
    avg_sample_center_speed = np.average(np.abs(np.diff(sample_x_center, axis=0)), axis=0)
    avg_data_speed = np.average(np.abs(np.diff(get_np(data), axis=0)), axis=0)

    if isinstance(model.transition, StationaryTransition):
        transition_matrix = model.transition.stationary_transition_matrix
    elif isinstance(model.transition, GridTransition):
        transition_matrix = model.transition.grid_transition_matrix
    else:
        raise ValueError("unsupported transition matrix type: {}".format(type(model.transition)))

    transition_matrix = get_np(transition_matrix)

    summary_dict = {"init_dist": get_np(model.init_dist),
                    "transition_matrix": transition_matrix,
                    "variance": get_np(torch.exp(model.observation.log_sigmas)),
                    "log_likes": get_np(model.log_likelihood(data, **memory_kwargs)),
                    "avg_data_speed": avg_data_speed,
                    "avg_sample_speed": avg_sample_speed, "avg_sample_center_speed": avg_sample_center_speed}
    summary_dict = {**dict_of_x_predict_k_err, **summary_dict}
    if len(valid_data) > 0:
        summary_dict["valid_log_likes"] = get_np(model.log_likelihood(valid_data, **valid_data_memory_kwargs))

    summary_dict["rs"] = get_np(model.observation.rs)
    summary_dict["avg_transform_speed"] = avg_transform_speed
    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = {"z": z, "z_valid": z_valid, "sample_z": sample_z, "sample_x": sample_x,
                   "sample_z_center": sample_z_center, "sample_x_center": sample_x_center}
    saving_dict = {**dict_of_x_predict_k, **saving_dict}

    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    if model.D == 2 and isinstance(model.transition, GridTransition):
        plot_grid_transition(n_x, n_y, model.transition.grid_transition_matrix)
        plt.savefig(rslt_dir + "/grid_transition.jpg")
        plt.close()

    if not os.path.exists(rslt_dir + "/samples"):
        os.makedirs(rslt_dir + "/samples")
        print("Making samples directory...")

    if K > 1:
        plot_z(z, K, title="most likely z for the ground truth")
        plt.savefig(rslt_dir + "/z.jpg")
        plt.close()

        if len(valid_data) > 0:
            plot_z(z_valid, K, title="most likely z for valid data")
            plt.savefig(rslt_dir + "/z_valid.jpg")
            plt.close()

        plot_z(sample_z, K, title="sample")
        plt.savefig(rslt_dir + "/samples/sample_z_{}.jpg".format(sample_T))
        plt.close()

        plot_z(sample_z_center, K, title="sample (starting from center)")
        plt.savefig(rslt_dir + "/samples/sample_z_center_{}.jpg".format(sample_T))
        plt.close()

    # BUGFIX: the upper x-limit previously used ARENA_YMAX; use ARENA_XMAX
    # (matching the sibling rslt_saving implementation).
    plt.figure(figsize=(4, 4))
    plot_mouse(data, title="ground truth (training)", xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/ground_truth.jpg")
    plt.close()

    if len(valid_data) > 0:
        plt.figure(figsize=(4, 4))
        plot_mouse(valid_data, title="ground truth (valid)", xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
                   ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
        plt.legend()
        plt.savefig(rslt_dir + "/samples/ground_truth_valid.jpg")
        plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x, title="sample", xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x_center, title="sample (starting from center)", xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_center_{}.jpg".format(sample_T))
    plt.close()

    if K > 1:
        plot_realdata_quiver(data, z, K, x_grids, y_grids, title="ground truth (training)")
        plt.savefig(rslt_dir + "/samples/quiver_ground_truth.jpg", dpi=200)
        plt.close()

        if len(valid_data) > 0:
            plot_realdata_quiver(valid_data, z_valid, K, x_grids, y_grids, title="ground truth (valid)")
            plt.savefig(rslt_dir + "/samples/quiver_ground_truth_valid.jpg", dpi=200)
            plt.close()

        plot_realdata_quiver(sample_x, sample_z, K, x_grids, y_grids, title="sample")
        plt.savefig(rslt_dir + "/samples/quiver_sample_x_{}.jpg".format(sample_T), dpi=200)
        plt.close()

        plot_realdata_quiver(sample_x_center, sample_z_center, K, x_grids, y_grids, title="sample (starting from center)")
        plt.savefig(rslt_dir + "/samples/quiver_sample_x_center_{}.jpg".format(sample_T), dpi=200)
        plt.close()

    if not os.path.exists(rslt_dir + "/dynamics"):
        os.makedirs(rslt_dir + "/dynamics")
        print("Making dynamics directory...")

    if animal == 'both':
        plot_quiver(XY_grids[..., 0:2], dXY[..., 0:2], 'virgin', K=K, scale=quiver_scale, alpha=0.9,
                    title="quiver (virgin)", x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_a.jpg", dpi=200)
        plt.close()

        plot_quiver(XY_grids[:, 2:4], dXY[..., 2:4], 'mother', K=K, scale=quiver_scale, alpha=0.9,
                    title="quiver (mother)", x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_b.jpg", dpi=200)
        plt.close()
    else:
        plot_quiver(XY_grids, dXY, animal, K=K, scale=quiver_scale, alpha=0.9,
                    title="quiver ({})".format(animal), x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_{}.jpg".format(animal), dpi=200)
        plt.close()

    if not os.path.exists(rslt_dir + "/distributions"):
        os.makedirs(rslt_dir + "/distributions")
        print("Making distributions directory...")

    # sanity checks (close each figure to avoid leaking open figures)
    plot_data_condition_on_all_zs(data, z, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_groundtruth.jpg", dpi=100)
    plt.close()
    plot_data_condition_on_all_zs(sample_x, sample_z, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_sample_x.jpg", dpi=100)
    plt.close()
    plot_data_condition_on_all_zs(sample_x_center, sample_z_center, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_sample_x_center.jpg", dpi=100)
    plt.close()

    plot_2d_time_plot_condition_on_all_zs(data, z, K, title='ground truth')
    plt.savefig(rslt_dir + "/distributions/4traces_groundtruth.jpg", dpi=100)
    plt.close()
    plot_2d_time_plot_condition_on_all_zs(sample_x, sample_z, K, title='sample_x')
    plt.savefig(rslt_dir + "/distributions/4traces_sample_x.jpg", dpi=100)
    plt.close()
    plot_2d_time_plot_condition_on_all_zs(sample_x_center, sample_z_center, K, title='sample_x_center')
    plt.savefig(rslt_dir + "/distributions/4traces_sample_x_center.jpg", dpi=100)
    plt.close()

    data_angles = get_all_angles(data, x_grids, y_grids, device=device)
    sample_angles = get_all_angles(sample_x, x_grids, y_grids, device=device)
    sample_x_center_angles = get_all_angles(sample_x_center, x_grids, y_grids,
                                                                        device=device)

    if animal == 'both':
        # BUGFIX: sample_x_center_angles must be indexed per animal ([0]/[1]),
        # as the speed plots below already do.
        plot_list_of_angles([data_angles[0], sample_angles[0], sample_x_center_angles[0]],
                            ['data', 'sample', 'sample_c'], "direction distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_a.jpg")
        plt.close()
        plot_list_of_angles([data_angles[1], sample_angles[1], sample_x_center_angles[1]],
                            ['data', 'sample', 'sample_c'], "direction distribution (mother)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_b.jpg")
        plt.close()
    else:
        plot_list_of_angles([data_angles, sample_angles, sample_x_center_angles],
                            ['data', 'sample', 'sample_c'], "direction distribution ({})".format(animal), n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_{}.jpg".format(animal))
        plt.close()

    data_speed = get_speed(data, x_grids, y_grids, device=device)
    sample_speed = get_speed(sample_x, x_grids, y_grids, device=device)
    sample_x_center_speed = get_speed(sample_x_center, x_grids, y_grids, device=device)

    if animal == 'both':
        plot_list_of_speed([data_speed[0], sample_speed[0], sample_x_center_speed[0]],
                           ['data', 'sample', 'sample_c'], "speed distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_a.jpg")
        plt.close()
        plot_list_of_speed([data_speed[1], sample_speed[1], sample_x_center_speed[1]],
                           ['data', 'sample', 'sample_c'], "speed distribution (mother)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_b.jpg")
        plt.close()
    else:
        plot_list_of_speed([data_speed, sample_speed, sample_x_center_speed],
                           ['data', 'sample', 'sample_c'], "speed distribution ({})".format(animal), n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_{}.jpg".format(animal))
        plt.close()

    # Space-occupancy histograms: require > 100 samples and cap at 36000 frames.
    # Best-effort: plotting failures here should not abort the whole save.
    try:
        if 100 < data.shape[0] <= 36000:
            plot_space_dist(data, x_grids, y_grids)
        elif data.shape[0] > 36000:
            plot_space_dist(data[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_data.jpg")
        plt.close()

        if 100 < sample_x.shape[0] <= 36000:
            plot_space_dist(sample_x, x_grids, y_grids)
        elif sample_x.shape[0] > 36000:
            plot_space_dist(sample_x[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x.jpg")
        plt.close()

        if 100 < sample_x_center.shape[0] <= 36000:
            plot_space_dist(sample_x_center, x_grids, y_grids)
        elif sample_x_center.shape[0] > 36000:
            plot_space_dist(sample_x_center[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x_center.jpg")
        plt.close()
    except Exception:
        # narrowed from a bare except: so KeyboardInterrupt/SystemExit still propagate
        print("plot_space_dist unsuccessful")
コード例 #11
0
def rslt_saving(rslt_dir, model, data, mouse, sample_T, train_model, losses,
                quiver_scale, x_grids, y_grids):
    """Infer states, sample, run 0/5-step prediction for a fitted
    transformation-based model, and save summaries/numbers/figures to ``rslt_dir``.

    :param rslt_dir: output directory; subdirectories are created as needed.
    :param model: fitted HMM with an observation exposing a ``transformation``.
    :param data: trajectory tensor, shape (T, 2) or (T, 4).
    :param mouse: label used in plot titles and file names.
    :param sample_T: length of the sampled trajectories.
    :param train_model: if truthy, save and plot the training ``losses``.
    :param losses: training loss history (only used when ``train_model``).
    :param quiver_scale: scale forwarded to the quiver plots.
    :param x_grids: x grid boundaries.
    :param y_grids: y grid boundaries.
    """
    tran = model.observation.transformation

    _, D = data.shape
    assert D == 2 or D == 4, "D must be either 2 or 4."

    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1

    K = model.K

    #################### inference ###########################

    print("\ninferring most likely states...")
    z = model.most_likely_states(data)

    print("0 step prediction")
    # Cap prediction at the last 5000 frames to bound runtime.
    if data.shape[0] <= 5000:
        data_to_predict = data
    else:
        data_to_predict = data[-5000:]
    x_predict = k_step_prediction(model, z, data_to_predict)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()),
                            axis=0)

    print("5 step prediction")
    x_predict_5 = k_step_prediction(model, z, data_to_predict, k=5)
    x_predict_5_err = np.mean(np.abs(x_predict_5 -
                                     data_to_predict[5:].numpy()),
                              axis=0)

    ################### samples #########################

    sample_z, sample_x = model.sample(sample_T)

    # Second sample is conditioned on a fixed "center" starting position.
    center_z = torch.tensor([0], dtype=torch.int)
    if D == 4:
        center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64)
    else:
        center_x = torch.tensor([[150, 190]], dtype=torch.float64)
    sample_z_center, sample_x_center = model.sample(sample_T,
                                                    prefix=(center_z,
                                                            center_x))

    ################## dynamics #####################

    # quiver: evaluate the learned transformation on a 30x30 grid of positions
    XX, YY = np.meshgrid(np.linspace(20, 310, 30), np.linspace(0, 380, 30))
    XY = np.column_stack(
        (np.ravel(XX), np.ravel(YY)))  # shape (900,2) grid values
    if D == 2:
        XY_grids = XY
    else:
        XY_grids = np.concatenate((XY, XY), axis=1)

    XY_next = tran.transform(torch.tensor(XY_grids, dtype=torch.float64))
    dXY = XY_next.detach().numpy() - XY_grids[:, None]

    #################### saving ##############################

    print("begin saving...")

    # save summary
    avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = np.average(np.abs(np.diff(sample_x, axis=0)), axis=0)
    avg_sample_center_speed = np.average(np.abs(
        np.diff(sample_x_center, axis=0)),
                                         axis=0)
    avg_data_speed = np.average(np.abs(np.diff(data.numpy(), axis=0)), axis=0)

    transition_matrix = model.transition.stationary_transition_matrix
    if transition_matrix.requires_grad:
        transition_matrix = transition_matrix.detach().numpy()
    else:
        transition_matrix = transition_matrix.numpy()

    cluster_centers = get_np(tran.mus_loc)

    summary_dict = {
        "init_dist": model.init_dist.detach().numpy(),
        "transition_matrix": transition_matrix,
        "x_predict_err": x_predict_err,
        "x_predict_5_err": x_predict_5_err,
        "mus": cluster_centers,
        "variance": torch.exp(model.observation.log_sigmas).detach().numpy(),
        "log_likes": model.log_likelihood(data).detach().numpy(),
        "avg_transform_speed": avg_transform_speed,
        "avg_data_speed": avg_data_speed,
        "avg_sample_speed": avg_sample_speed,
        "avg_sample_center_speed": avg_sample_center_speed
    }
    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = {
        "z": z,
        "x_predict": x_predict,
        "x_predict_5": x_predict_5,
        "sample_z": sample_z,
        "sample_x": sample_x,
        "sample_z_center": sample_z_center,
        "sample_x_center": sample_x_center
    }

    if train_model:
        saving_dict['losses'] = losses
        plt.figure()
        plt.plot(losses)
        plt.savefig(rslt_dir + "/losses.jpg")
        plt.close()
    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    plot_z(z, K, title="most likely z for the ground truth")
    plt.savefig(rslt_dir + "/z.jpg")
    plt.close()

    if not os.path.exists(rslt_dir + "/samples"):
        os.makedirs(rslt_dir + "/samples")
        print("Making samples directory...")

    plot_z(sample_z, K, title="sample")
    plt.savefig(rslt_dir + "/samples/sample_z_{}.jpg".format(sample_T))
    plt.close()

    plot_z(sample_z_center, K, title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/sample_z_center_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(data,
               title="ground truth_{}".format(mouse),
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/ground_truth.jpg")
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x,
               title="sample_{}".format(mouse),
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x_center,
               title="sample (starting from center)_{}".format(mouse),
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_center_{}.jpg".format(sample_T))
    plt.close()

    plot_realdata_quiver(data,
                         z,
                         K,
                         x_grids,
                         y_grids,
                         title="ground truth",
                         cluster_centers=cluster_centers)
    plt.savefig(rslt_dir + "/samples/quiver_ground_truth.jpg", dpi=200)
    # BUGFIX: this figure was previously left open (missing close)
    plt.close()

    plot_realdata_quiver(sample_x,
                         sample_z,
                         K,
                         x_grids,
                         y_grids,
                         title="sample",
                         cluster_centers=cluster_centers)
    plt.savefig(rslt_dir + "/samples/quiver_sample_x_{}.jpg".format(sample_T),
                dpi=200)
    plt.close()

    plot_realdata_quiver(sample_x_center,
                         sample_z_center,
                         K,
                         x_grids,
                         y_grids,
                         title="sample (starting from center)",
                         cluster_centers=cluster_centers)
    plt.savefig(rslt_dir +
                "/samples/quiver_sample_x_center_{}.jpg".format(sample_T),
                dpi=200)
    plt.close()

    # plot mus
    plot_cluster_centers(cluster_centers, x_grids, y_grids)
    plt.savefig(rslt_dir + "/samples/cluster_centers.jpg", dpi=200)
    # BUGFIX: this figure was previously left open (missing close)
    plt.close()

    if not os.path.exists(rslt_dir + "/dynamics"):
        os.makedirs(rslt_dir + "/dynamics")
        print("Making dynamics directory...")

    if D == 2:
        plot_quiver(XY_grids,
                    dXY,
                    mouse,
                    K=K,
                    scale=quiver_scale,
                    alpha=0.9,
                    title="quiver ({})".format(mouse),
                    x_grids=x_grids,
                    y_grids=y_grids,
                    grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_{}.jpg".format(mouse),
                    dpi=200)
        plt.close()
    else:
        plot_quiver(XY_grids[:, 0:2],
                    dXY[..., 0:2],
                    'virgin',
                    K=K,
                    scale=quiver_scale,
                    alpha=0.9,
                    title="quiver (virgin)",
                    x_grids=x_grids,
                    y_grids=y_grids,
                    grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_a.jpg", dpi=200)
        plt.close()

        plot_quiver(XY_grids[:, 2:4],
                    dXY[..., 2:4],
                    'mother',
                    K=K,
                    scale=quiver_scale,
                    alpha=0.9,
                    title="quiver (mother)",
                    x_grids=x_grids,
                    y_grids=y_grids,
                    grid_alpha=0.2)
        plt.savefig(rslt_dir + "/dynamics/quiver_b.jpg", dpi=200)
        plt.close()

    if not os.path.exists(rslt_dir + "/distributions"):
        os.makedirs(rslt_dir + "/distributions")
        print("Making distributions directory...")

    if D == 4:
        data_angles_a, data_angles_b = get_all_angles(data, x_grids, y_grids)
        sample_angles_a, sample_angles_b = get_all_angles(
            sample_x, x_grids, y_grids)
        sample_x_center_angles_a, sample_x_center_angles_b = get_all_angles(
            sample_x_center, x_grids, y_grids)

        plot_list_of_angles(
            [data_angles_a, sample_angles_a, sample_x_center_angles_a],
            ['data', 'sample', 'sample_c'], "direction distribution (virgin)",
            n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_a.jpg")
        plt.close()
        plot_list_of_angles(
            [data_angles_b, sample_angles_b, sample_x_center_angles_b],
            ['data', 'sample', 'sample_c'], "direction distribution (mother)",
            n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_b.jpg")
        plt.close()

        data_speed_a, data_speed_b = get_speed(data, x_grids, y_grids)
        sample_speed_a, sample_speed_b = get_speed(sample_x, x_grids, y_grids)
        sample_x_center_speed_a, sample_x_center_speed_b = get_speed(
            sample_x_center, x_grids, y_grids)

        plot_list_of_speed(
            [data_speed_a, sample_speed_a, sample_x_center_speed_a],
            ['data', 'sample', 'sample_c'], "speed distribution (virgin)", n_x,
            n_y)
        plt.savefig(rslt_dir + "/distributions/speed_a.jpg")
        plt.close()
        plot_list_of_speed(
            [data_speed_b, sample_speed_b, sample_x_center_speed_b],
            ['data', 'sample', 'sample_c'], "speed distribution (mother)", n_x,
            n_y)
        plt.savefig(rslt_dir + "/distributions/speed_b.jpg")
        plt.close()
    else:
        data_angles_a = get_all_angles(data, x_grids, y_grids)
        sample_angles_a = get_all_angles(sample_x, x_grids, y_grids)
        sample_x_center_angles_a = get_all_angles(sample_x_center, x_grids,
                                                  y_grids)

        # BUGFIX: single-mouse figures were hard-coded "(virgin)" although the
        # files are saved under the mouse's own name; label them consistently.
        plot_list_of_angles(
            [data_angles_a, sample_angles_a, sample_x_center_angles_a],
            ['data', 'sample', 'sample_c'],
            "direction distribution ({})".format(mouse), n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_{}.jpg".format(mouse))
        plt.close()

        data_speed_a = get_speed(data, x_grids, y_grids)
        sample_speed_a = get_speed(sample_x, x_grids, y_grids)
        sample_x_center_speed_a = get_speed(sample_x_center, x_grids, y_grids)

        plot_list_of_speed(
            [data_speed_a, sample_speed_a, sample_x_center_speed_a],
            ['data', 'sample', 'sample_c'],
            "speed distribution ({})".format(mouse), n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_{}.jpg".format(mouse))
        plt.close()

    # Space-occupancy histograms: require > 100 samples and cap at 36000 frames.
    # Best-effort: plotting failures here should not abort the whole save.
    try:
        if 100 < data.shape[0] <= 36000:
            plot_space_dist(data, x_grids, y_grids)
        elif data.shape[0] > 36000:
            plot_space_dist(data[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_data.jpg")
        plt.close()

        if 100 < sample_x.shape[0] <= 36000:
            plot_space_dist(sample_x, x_grids, y_grids)
        elif sample_x.shape[0] > 36000:
            plot_space_dist(sample_x[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x.jpg")
        plt.close()

        if 100 < sample_x_center.shape[0] <= 36000:
            plot_space_dist(sample_x_center, x_grids, y_grids)
        elif sample_x_center.shape[0] > 36000:
            plot_space_dist(sample_x_center[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x_center.jpg")
        plt.close()
    except Exception:
        # narrowed from a bare except: so KeyboardInterrupt/SystemExit still propagate
        print("plot_space_dist unsuccessful")
コード例 #12
0
def test_model():
    """Smoke-test an HMM with a UniLSTMTransformation-based observation:
    build a model on random (T, 4) data, fit briefly, decode states, run
    0- and 10-step prediction, and draw a short sample."""
    # Fix seeds so the test is reproducible.
    torch.manual_seed(0)
    np.random.seed(0)

    T = 100
    D = 4

    data = np.random.randn(T, D)
    data = torch.tensor(data, dtype=torch.float64)

    # Bounds cover both animals' coordinates with a 1-unit margin.
    xmax = max(np.max(data[:, 0].numpy()), np.max(data[:, 2].numpy()))
    xmin = min(np.min(data[:, 0].numpy()), np.min(data[:, 2].numpy()))
    ymax = max(np.max(data[:, 1].numpy()), np.max(data[:, 3].numpy()))
    ymin = min(np.min(data[:, 1].numpy()), np.min(data[:, 3].numpy()))
    bounds = np.array([[xmin - 1, xmax + 1], [ymin - 1, ymax + 1],
                       [xmin - 1, xmax + 1], [ymin - 1, ymax + 1]])

    def toy_feature_vec_func(s):
        """Direction-to-corner features.

        :param s: self positions, (T, 2)
        :return: features, (T, Df, 2)
        """
        corners = torch.tensor([[0, 0], [0, 8], [10, 0], [10, 8]],
                               dtype=torch.float64)
        return feature_direction_vec(s, corners)

    K = 3

    Df = 4
    lags = 1

    tran = UniLSTMTransformation(K=K,
                                 D=D,
                                 Df=Df,
                                 feature_vec_func=toy_feature_vec_func,
                                 lags=lags,
                                 dh=10)

    # observation
    obs = ARTruncatedNormalObservation(K=K,
                                       D=D,
                                       lags=lags,
                                       bounds=bounds,
                                       transformation=tran)

    # model
    model = HMM(K=K, D=D, observation=obs)

    print("calculating log likelihood")
    # Precompute per-animal features and packed lagged data once, reused by
    # log_likelihood / fit / most_likely_states below.
    feature_vecs_a = toy_feature_vec_func(data[:-1, 0:2])
    feature_vecs_b = toy_feature_vec_func(data[:-1, 2:4])
    feature_vecs = (feature_vecs_a, feature_vecs_b)
    packed_data = get_packed_data((data[:-1]), lags=lags)

    model.log_likelihood(data,
                         feature_vecs=feature_vecs,
                         packed_data=packed_data)

    # fit
    losses, _ = model.fit(data,
                          optimizer=None,
                          method="adam",
                          num_iters=50,
                          feature_vecs=feature_vecs,
                          packed_data=packed_data)

    plt.figure()
    plt.plot(losses)
    plt.show()

    # most-likely-z
    print("Most likely z...")
    z = model.most_likely_states(data,
                                 feature_vecs=feature_vecs,
                                 packed_data=packed_data)

    # prediction

    print("0 step prediction")
    # Cap prediction at the last 1000 frames (the original duplicated this
    # truncation block verbatim; one copy suffices).
    if data.shape[0] <= 1000:
        data_to_predict = data
    else:
        data_to_predict = data[-1000:]
    x_predict = k_step_prediction_for_lstm_model(model,
                                                 z,
                                                 data_to_predict,
                                                 feature_vecs=feature_vecs)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()),
                            axis=0)

    print("10 step prediction")
    x_predict_2 = k_step_prediction(model, z, data_to_predict, k=10)
    x_predict_2_err = np.mean(np.abs(x_predict_2 -
                                     data_to_predict[10:].numpy()),
                              axis=0)

    # samples
    print("sampling...")
    sample_T = 5
    sample_z, sample_x = model.sample(sample_T)
コード例 #13
0
def k_step_prediction_for_lstm_based_model(model,
                                           model_z,
                                           data,
                                           k=0,
                                           feature_vecs=None):
    """Compute k-step-ahead predictions for an LSTM-based model.

    For ``k == 0``, draws one-step "predictions" conditioned on the true
    history ``data[:t]`` and the inferred state ``model_z[t]`` (falls back to
    the generic ``k_step_prediction`` when ``feature_vecs`` is not supplied).
    For ``k > 0``, rolls the model forward ``k`` steps from each time point,
    feeding sampled values back in after the first real-data step.

    :param model: HMM-like model whose observation supports ``sample_x`` and
        whose ``sample`` accepts an ``lstm_states`` dict (project type).
    :param model_z: inferred discrete states, indexable per time step.
    :param data: (T, D) observations; converted to a tensor.
    :param k: number of steps to predict ahead, k >= 0.
    :param feature_vecs: tuple ``(feature_vecs_a, feature_vecs_b)`` of
        precomputed features; only used when ``k == 0``.
    :return: np.ndarray of shape (T - k, D) with the predictions.
    :raises ValueError: if T <= k, so no k-step prediction exists.
    """
    data = check_and_convert_to_tensor(data)
    T, D = data.shape

    # Mutated in place by model.observation.sample_x / model.sample so the
    # LSTM hidden state carries over between successive calls.
    lstm_states = {}

    x_predict_arr = []
    if k == 0:
        if feature_vecs is None:
            print("Did not provide memory information")
            return k_step_prediction(model, model_z, data)
        else:
            feature_vecs_a, feature_vecs_b = feature_vecs

            # NOTE(review): unlike the analogous lineargrid helper, the t = 0
            # sample does not pass with_noise=True -- confirm this asymmetry
            # is intentional.
            x_predict = model.observation.sample_x(model_z[0],
                                                   data[:0],
                                                   return_np=True)
            x_predict_arr.append(x_predict)
            for t in range(1, data.shape[0]):
                feature_vec_t = (feature_vecs_a[t - 1:t],
                                 feature_vecs_b[t - 1:t])

                x_predict = model.observation.sample_x(
                    model_z[t],
                    data[:t],
                    return_np=True,
                    with_noise=True,
                    feature_vec=feature_vec_t,
                    lstm_states=lstm_states)
                x_predict_arr.append(x_predict)
    else:
        assert k > 0
        # neglects t = 0 since there is no history

        if T <= k:
            raise ValueError("Please input k such that k < {}.".format(T))

        for t in range(1, T - k + 1):
            # sample k steps forward
            # first step use real value
            z, x = model.sample(1,
                                prefix=(model_z[t - 1:t], data[t - 1:t]),
                                return_np=False,
                                with_noise=True,
                                lstm_states=lstm_states)
            # last k-1 steps use sampled values; copy the LSTM state so the
            # rollout does not disturb the state driven by the real data.
            # (The former `if k >= 1:` guard was always true after the assert
            # above and has been removed; range(k - 1) is already empty for
            # k <= 1.)
            sampled_lstm_states = dict(h_t=lstm_states["h_t"],
                                       c_t=lstm_states["c_t"])
            for i in range(k - 1):
                z, x = model.sample(1,
                                    prefix=(z, x),
                                    return_np=False,
                                    with_noise=True,
                                    lstm_states=sampled_lstm_states)
            assert x.shape == (1, D)
            x_predict_arr.append(get_np(x[0]))

    x_predict_arr = np.array(x_predict_arr)
    assert x_predict_arr.shape == (T - k, D)
    return x_predict_arr
コード例 #14
0
def test_model():
    """Smoke-test an HMM with a LinearGridTransformation observation model.

    Builds a 2x2-grid transformation over a tiny hard-coded 4-D trajectory
    (two 2-D agents), precomputes the grid/feature "memory" once, fits the
    model briefly, infers the most likely states, and runs 0-step / 2-step
    prediction plus sampling end-to-end.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    x_grids = np.array([0.0, 5.0, 10.0])
    y_grids = np.array([0.0, 4.0, 8.0])

    # Two 2-D agents stacked as 4-D rows: (x_a, y_a, x_b, y_b).
    data = np.array([[1.0, 1.0, 1.0, 6.0], [3.0, 6.0, 8.0, 6.0],
                     [4.0, 7.0, 8.0, 5.0], [6.0, 7.0, 5.0, 6.0],
                     [8.0, 2.0, 6.0, 1.0]])
    data = torch.tensor(data, dtype=torch.float64)

    def toy_feature_vec_func(s):
        """
        :param s: self, (T, 2)
        :return: features, (T, Df, 2) -- direction vectors from s towards
            the four arena corners
        """
        corners = torch.tensor([[0, 0], [0, 8], [10, 0], [10, 8]],
                               dtype=torch.float64)
        return feature_direction_vec(s, corners)

    K = 3
    D = 4
    M = 0

    Df = 4

    bounds = np.array([[0.0, 10.0], [0.0, 8.0], [0.0, 10.0], [0.0, 8.0]])
    tran = LinearGridTransformation(K=K,
                                    D=D,
                                    x_grids=x_grids,
                                    y_grids=y_grids,
                                    Df=Df,
                                    feature_vec_func=toy_feature_vec_func)
    obs = ARTruncatedNormalObservation(K=K,
                                       D=D,
                                       M=0,
                                       lags=1,
                                       bounds=bounds,
                                       transformation=tran)

    model = HMM(K=K, D=D, M=M, transition="stationary", observation=obs)
    model.observation.mus_init = data[0] * torch.ones(
        K, D, dtype=torch.float64)

    # calculate memory: grid indices/points and corner features for both
    # agents, computed once and reused by fit / inference / prediction.
    gridpoints_idx_a = tran.get_gridpoints_idx_for_batch(data[:-1, 0:2])
    gridpoints_idx_b = tran.get_gridpoints_idx_for_batch(data[:-1, 2:4])
    gridpoints_a = tran.get_gridpoints_for_batch(gridpoints_idx_a)
    gridpoints_b = tran.get_gridpoints_for_batch(gridpoints_idx_b)
    feature_vecs_a = toy_feature_vec_func(data[:-1, 0:2])
    feature_vecs_b = toy_feature_vec_func(data[:-1, 2:4])

    gridpoints_idx = (gridpoints_idx_a, gridpoints_idx_b)
    gridpoints = (gridpoints_a, gridpoints_b)
    feature_vecs = (feature_vecs_a, feature_vecs_b)

    # fit
    losses, opt = model.fit(data,
                            optimizer=None,
                            method='adam',
                            num_iters=100,
                            lr=0.01,
                            pbar_update_interval=10,
                            gridpoints=gridpoints,
                            gridpoints_idx=gridpoints_idx,
                            feature_vecs=feature_vecs)

    plt.figure()
    plt.plot(losses)
    plt.show()

    # most-likely-z
    print("Most likely z...")
    z = model.most_likely_states(data,
                                 gridpoints_idx=gridpoints_idx,
                                 feature_vecs=feature_vecs)

    # prediction
    print("0 step prediction")
    if data.shape[0] <= 1000:
        data_to_predict = data
    else:
        data_to_predict = data[-1000:]
    # BUG FIX: gridpoints was previously not passed, so the helper printed
    # "Did not provide memory information" and silently fell back to the
    # generic k_step_prediction instead of using the precomputed memory.
    x_predict = k_step_prediction_for_lineargrid_model(
        model,
        z,
        data_to_predict,
        gridpoints=gridpoints,
        gridpoints_idx=gridpoints_idx,
        feature_vecs=feature_vecs)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()),
                            axis=0)

    print("2 step prediction")
    x_predict_2 = k_step_prediction(model, z, data_to_predict, k=2)
    x_predict_2_err = np.mean(np.abs(x_predict_2 -
                                     data_to_predict[2:].numpy()),
                              axis=0)

    # samples
    sample_T = 5
    sample_z, sample_x = model.sample(sample_T)
コード例 #15
0
def test_model():
    """Smoke-test an HMM with a GPObservation on a single grid cell.

    Checks the precomputed kernel squared-distance tables against
    hand-computed values, verifies that log_prob is identical with and
    without the two levels of cached "memory" kwargs, then fits the model
    briefly and runs inference, prediction and sampling end-to-end.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    # One grid cell spanning [0, 10] x [0, 8]; T frames of data below.
    T = 5
    x_grids = np.array([0.0, 10.0])
    y_grids = np.array([0.0, 8.0])

    bounds = np.array([[0.0, 10.0], [0.0, 8.0], [0.0, 10.0], [0.0, 8.0]])
    # Two 2-D agents stacked as 4-D rows: (x_a, y_a, x_b, y_b).
    data = np.array([[1.0, 1.0, 1.0, 6.0], [3.0, 6.0, 8.0, 6.0],
                     [4.0, 7.0, 8.0, 5.0], [6.0, 7.0, 5.0, 6.0],
                     [8.0, 2.0, 6.0, 1.0]])
    data = torch.tensor(data, dtype=torch.float64)

    K = 3
    D = 4
    M = 0

    obs = GPObservation(K=K,
                        D=D,
                        x_grids=x_grids,
                        y_grids=y_grids,
                        bounds=bounds,
                        train_rs=True)

    # Expected pairwise squared distances between the (doubled) inducing
    # points at the four grid corners, hand-computed for this grid.
    correct_kerneldist_gg = torch.tensor(
        [[0., 0., 64., 64., 100., 100., 164., 164.],
         [0., 0., 64., 64., 100., 100., 164., 164.],
         [64., 64., 0., 0., 164., 164., 100., 100.],
         [64., 64., 0., 0., 164., 164., 100., 100.],
         [100., 100., 164., 164., 0., 0., 64., 64.],
         [100., 100., 164., 164., 0., 0., 64., 64.],
         [164., 164., 100., 100., 64., 64., 0., 0.],
         [164., 164., 100., 100., 64., 64., 0., 0.]],
        dtype=torch.float64)
    assert torch.all(torch.eq(correct_kerneldist_gg,
                              obs.kernel_distsq_gg)), obs.kernel_distsq_gg

    # Reference log-prob computed without any cached quantities.
    log_prob_nocache = obs.log_prob(data)
    print("log_prob_nocache = {}".format(log_prob_nocache))

    # First cache level: data-to-inducing-point squared distances per agent.
    kernel_distsq_xg_a = kernel_distsq_doubled(data[:-1, 0:2],
                                               obs.inducing_points)
    kernel_distsq_xg_b = kernel_distsq_doubled(data[:-1, 2:4],
                                               obs.inducing_points)

    correct_kernel_distsq_xg_a = torch.tensor(
        [[2., 2., 50., 50., 82., 82., 130., 130.],
         [2., 2., 50., 50., 82., 82., 130., 130.],
         [45., 45., 13., 13., 85., 85., 53., 53.],
         [45., 45., 13., 13., 85., 85., 53., 53.],
         [65., 65., 17., 17., 85., 85., 37., 37.],
         [65., 65., 17., 17., 85., 85., 37., 37.],
         [85., 85., 37., 37., 65., 65., 17., 17.],
         [85., 85., 37., 37., 65., 65., 17., 17.]],
        dtype=torch.float64)
    assert torch.all(torch.eq(correct_kernel_distsq_xg_a,
                              kernel_distsq_xg_a)), kernel_distsq_xg_a

    memory_kwargs = dict(kernel_distsq_xg_a=kernel_distsq_xg_a,
                         kernel_distsq_xg_b=kernel_distsq_xg_b)

    # Cached log-prob must match the uncached reference exactly.
    log_prob = obs.log_prob(data, **memory_kwargs)
    print("log_prob = {}".format(log_prob))

    assert torch.all(torch.eq(log_prob_nocache, log_prob))

    # Second cache level: GP cache quantities (Sigma, A) per agent.
    Sigma_a, A_a = obs.get_gp_cache(data[:-1, 0:2], 0, **memory_kwargs)
    Sigma_b, A_b = obs.get_gp_cache(data[:-1, 2:4], 1, **memory_kwargs)
    memory_kwargs_2 = dict(Sigma_a=Sigma_a, A_a=A_a, Sigma_b=Sigma_b, A_b=A_b)

    print("calculating log prob 2...")
    log_prob2 = obs.log_prob(data, **memory_kwargs_2)
    assert torch.all(torch.eq(log_prob, log_prob2))

    model = HMM(K=K, D=D, M=M, transition="stationary", observation=obs)
    model.observation.mus_init = data[0] * torch.ones(
        K, D, dtype=torch.float64)

    # fit
    losses, opt = model.fit(data,
                            optimizer=None,
                            method='adam',
                            num_iters=100,
                            lr=0.01,
                            pbar_update_interval=10,
                            **memory_kwargs)

    plt.figure()
    plt.plot(losses)
    plt.show()

    # most-likely-z
    print("Most likely z...")
    z = model.most_likely_states(data, **memory_kwargs)

    # prediction
    print("0 step prediction")
    if data.shape[0] <= 1000:
        data_to_predict = data
    else:
        data_to_predict = data[-1000:]
    x_predict = k_step_prediction_for_gpmodel(model, z, data_to_predict,
                                              **memory_kwargs)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()),
                            axis=0)

    print("2 step prediction")
    x_predict_2 = k_step_prediction(model, z, data_to_predict, k=2)
    x_predict_2_err = np.mean(np.abs(x_predict_2 -
                                     data_to_predict[2:].numpy()),
                              axis=0)

    # samples
    sample_T = 5
    sample_z, sample_x = model.sample(sample_T)
コード例 #16
0
# Persist training losses and visualize them.
joblib.dump(losses, "losses")
plt.plot(losses)
# sampling
print("start sampling")
sample_z, sample_x = model.sample(T)
plot_realdata_quiver(sample_x, scale=1, title="sample")
plt.show()
joblib.dump((sample_z, sample_x), "samples")

# inference
print("inferring most likely states...")  # fixed typo: was "inferiring"
z = model.most_likely_states(data,
                             memory_kwargs_a=m_kwargs_a,
                             memory_kwargs_b=m_kwargs_b)
joblib.dump(z, "z")

# Evaluate predictions on (at most) the last 1000 frames.
data_to_predict = data[-1000:]
print("0 step prediction")
x_predict = k_step_prediction_for_grid_model(model,
                                             z,
                                             data_to_predict,
                                             memory_kwargs_a=m_kwargs_a,
                                             memory_kwargs_b=m_kwargs_b)
# Mean absolute 0-step prediction error per dimension.
err = np.mean(np.abs(x_predict - data_to_predict.numpy()), axis=0)
print(err)

print("k step prediction")
x_predict_5 = k_step_prediction(model, z, data_to_predict, 5)
# 5-step predictions align with data_to_predict[5:].
err = np.mean(np.abs(x_predict_5 - data_to_predict[5:].numpy()), axis=0)
print(err)
コード例 #17
0
# model
model = HMM(K=K, D=D, M=M, observation=obs)

# log like
log_prob = model.log_probability(data)
print("log probability = ", log_prob)

# training
print("start training...")
num_iters = 10
losses, opt = model.fit(data, num_iters=num_iters, lr=0.001)

# sampling
sample_T = T  # fixed typo: variable was misspelled "samplt_T"
print("start sampling")
sample_z, sample_x = model.sample(sample_T)

print("start sampling based on with_noise")
sample_z2, sample_x2 = model.sample(T, with_noise=True)

# inference
print("inferring most likely states...")  # fixed typo: was "inferiring"
z = model.most_likely_states(data)

print("0 step prediction")
x_predict = k_step_prediction(model, z, data)

# k-step prediction with k = 2 (renamed from the misleading x_predict_10).
print("k step prediction")
x_predict_2 = k_step_prediction(model, z, data, 2)
コード例 #18
0
# Manual training loop: one optimizer step per iteration on the model loss.
losses = []
for i in np.arange(num_iters):

    optimizer.zero_grad()

    loss = model.loss(data)
    # retain_graph=True keeps the graph alive across backward passes --
    # NOTE(review): confirm this is required; it can hide a graph-reuse bug.
    loss.backward(retain_graph=True)
    optimizer.step()

    # Detach to a plain numpy scalar so the stored history holds no graph.
    loss = loss.detach().numpy()
    losses.append(loss)

    if i % 10 == 0:
        pbar.set_description('iter {} loss {:.2f}'.format(i, loss))
        pbar.update(10)

# check reconstruction
x_reconstruct = model.sample_condition_on_zs(z, data[0])

# infer the latent states
infer_z = model.most_likely_states(data)

# Align inferred state labels with the ground-truth labeling before comparing.
perm = find_permutation(z.numpy(), infer_z, K1=K, K2=K)

model.permute(perm)
hmm_z = model.most_likely_states(data)

# check prediction
x_predict_cond_z = k_step_prediction(model, z, data)
コード例 #19
0
# Lag-1 autoregressive setup with identity dynamics for every state.
lags = 1

bounds = np.array([[-2, 2], [0, 1], [-2, 2], [0, 1]])

# K identical AR matrices [I_D | 0]; the zero block is empty when lags == 1.
As = np.tile(np.hstack([np.identity(D), np.zeros((D, (lags - 1) * D))]),
             (K, 1, 1))

torch.manual_seed(0)
np.random.seed(0)

tran = LinearTransformation(K=K, D=D, lags=lags, As=As)
observation = ARTruncatedNormalObservation(K=K,
                                           D=D,
                                           M=0,
                                           transformation=tran,
                                           bounds=bounds)

model = HMM(K=K, D=D, M=0, observation=observation)

# Log-likelihood of the data under the untrained model.
lls = model.log_likelihood(data)
print(lls)

#losses_1, optimizer_1 = model_1.fit(data_1, method='adam', num_iters=2000, lr=0.001)

# Most likely state sequence, then 0-step prediction conditioned on it.
z_1 = model.most_likely_states(data)

x_predict_arr_lag1 = k_step_prediction(model, z_1, data)
コード例 #20
0
def rslt_saving(rslt_dir, model, Df, data, masks_a, masks_b, m_kwargs_a,
                m_kwargs_b, sample_T, train_model, losses, quiver_scale):

    tran = model.observation.transformation
    x_grids = tran.x_grids
    y_grids = tran.y_grids
    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1
    G = n_x * n_y
    f_corner_vec_func = tran.transformations_a[0].feature_vec_func
    K = model.K
    Df = Df

    #################### inference ###########################

    print("\ninferring most likely states...")
    z = model.most_likely_states(data,
                                 masks=(masks_a, masks_b),
                                 memory_kwargs_a=m_kwargs_a,
                                 memory_kwargs_b=m_kwargs_b)

    print("0 step prediction")
    if data.shape[0] <= 1000:
        data_to_predict = data
    else:
        data_to_predict = data[-1000:]
    x_predict = k_step_prediction_for_grid_model(model,
                                                 z,
                                                 data_to_predict,
                                                 memory_kwargs_a=m_kwargs_a,
                                                 memory_kwargs_b=m_kwargs_b)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()),
                            axis=0)

    print("5 step prediction")
    x_predict_5 = k_step_prediction(model, z, data_to_predict, k=5)
    x_predict_5_err = np.mean(np.abs(x_predict_5 -
                                     data_to_predict[5:].numpy()),
                              axis=0)

    ################### samples #########################

    sample_z, sample_x = model.sample(sample_T)

    center_z = torch.tensor([0], dtype=torch.int)
    center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64)
    sample_z_center, sample_x_center = model.sample(sample_T,
                                                    prefix=(center_z,
                                                            center_x))

    ################## dynamics #####################

    # weights
    weights_a = np.array(
        [t.weights.detach().numpy() for t in tran.transformations_a])
    weights_b = np.array(
        [t.weights.detach().numpy() for t in tran.transformations_b])

    # dynamics
    grid_centers = np.array([[
        1 / 2 * (x_grids[i] + x_grids[i + 1]),
        1 / 2 * (y_grids[j] + y_grids[j + 1])
    ] for i in range(n_x) for j in range(n_y)])
    unit_corner_vecs = f_corner_vec_func(
        torch.tensor(grid_centers, dtype=torch.float64))
    unit_corner_vecs = unit_corner_vecs.numpy()
    # (G, 1, Df, d) * (G, K, Df, 1) --> (G, K, Df, d)
    weighted_corner_vecs_a = unit_corner_vecs[:, None] * weights_a[..., None]
    weighted_corner_vecs_b = unit_corner_vecs[:, None] * weights_b[..., None]

    grid_z_a_percentage = get_z_percentage_by_grid(masks_a, z, K, G)
    grid_z_b_percentage = get_z_percentage_by_grid(masks_b, z, K, G)

    # quiver
    XX, YY = np.meshgrid(np.linspace(20, 310, 30), np.linspace(0, 380, 30))
    XY = np.column_stack(
        (np.ravel(XX), np.ravel(YY)))  # shape (900,2) grid values
    XY_grids = np.concatenate((XY, XY), axis=1)

    XY_next = tran.transform(torch.tensor(XY_grids, dtype=torch.float64))
    dXY = XY_next.detach().numpy() - XY_grids[:, None]

    #################### saving ##############################

    print("begin saving...")

    # save summary
    avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = np.average(np.abs(np.diff(sample_x, axis=0)), axis=0)
    avg_sample_center_speed = np.average(np.abs(
        np.diff(sample_x_center, axis=0)),
                                         axis=0)
    avg_data_speed = np.average(np.abs(np.diff(data.numpy(), axis=0)), axis=0)

    transition_matrix = model.transition.stationary_transition_matrix
    if transition_matrix.requires_grad:
        transition_matrix = transition_matrix.detach().numpy()
    else:
        transition_matrix = transition_matrix.numpy()
    summary_dict = {
        "init_dist": model.init_dist.detach().numpy(),
        "transition_matrix": transition_matrix,
        "x_predict_err": x_predict_err,
        "x_predict_5_err": x_predict_5_err,
        "variance": torch.exp(model.observation.log_sigmas).detach().numpy(),
        "log_likes": model.log_likelihood(data).detach().numpy(),
        "grid_z_a_percentage": grid_z_a_percentage,
        "grid_z_b_percentage": grid_z_b_percentage,
        "avg_transform_speed": avg_transform_speed,
        "avg_data_speed": avg_data_speed,
        "avg_sample_speed": avg_sample_speed,
        "avg_sample_center_speed": avg_sample_center_speed
    }
    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = {
        "z": z,
        "x_predict": x_predict,
        "x_predict_5": x_predict_5,
        "sample_z": sample_z,
        "sample_x": sample_x,
        "sample_z_center": sample_z_center,
        "sample_x_center": sample_x_center
    }

    if train_model:
        saving_dict['losses'] = losses
        plt.figure()
        plt.plot(losses)
        plt.savefig(rslt_dir + "/losses.jpg")
        plt.close()
    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    plot_z(z, K, title="most likely z for the ground truth")
    plt.savefig(rslt_dir + "/z.jpg")
    plt.close()

    if not os.path.exists(rslt_dir + "/samples"):
        os.makedirs(rslt_dir + "/samples")
        print("Making samples directory...")

    plot_z(sample_z, K, title="sample")
    plt.savefig(rslt_dir + "/samples/sample_z_{}.jpg".format(sample_T))
    plt.close()

    plot_z(sample_z_center, K, title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/sample_z_center_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(data,
               title="ground truth",
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/ground_truth.jpg")
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x,
               title="sample",
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x_center,
               title="sample (starting from center)",
               xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
               ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_center_{}.jpg".format(sample_T))
    plt.close()

    plot_realdata_quiver(data, z, K, x_grids, y_grids, title="ground truth")
    plt.savefig(rslt_dir + "/samples/ground_truth_quiver.jpg")

    plot_realdata_quiver(sample_x,
                         sample_z,
                         K,
                         x_grids,
                         y_grids,
                         title="sample")
    plt.savefig(rslt_dir + "/samples/sample_x_quiver_{}.jpg".format(sample_T))
    plt.close()

    plot_realdata_quiver(sample_x_center,
                         sample_z_center,
                         K,
                         x_grids,
                         y_grids,
                         title="sample (starting from center)")
    plt.savefig(rslt_dir +
                "/samples/sample_x_center_quiver_{}.jpg".format(sample_T))
    plt.close()

    if not os.path.exists(rslt_dir + "/dynamics"):
        os.makedirs(rslt_dir + "/dynamics")
        print("Making dynamics directory...")

    plot_weights(weights_a,
                 Df,
                 K,
                 x_grids,
                 y_grids,
                 max_weight=tran.transformations_a[0].acc_factor,
                 title="weights (virgin)")
    plt.savefig(rslt_dir + "/dynamics/weights_a.jpg")
    plt.close()

    plot_weights(weights_b,
                 Df,
                 K,
                 x_grids,
                 y_grids,
                 max_weight=tran.transformations_b[0].acc_factor,
                 title="weights (mother)")
    plt.savefig(rslt_dir + "/dynamics/weights_b.jpg")
    plt.close()

    plot_dynamics(weighted_corner_vecs_a,
                  "virgin",
                  x_grids,
                  y_grids,
                  K=K,
                  scale=quiver_scale,
                  percentage=grid_z_a_percentage,
                  title="grid dynamics (virgin)")
    plt.savefig(rslt_dir + "/dynamics/dynamics_a.jpg")
    plt.close()

    plot_dynamics(weighted_corner_vecs_b,
                  "mother",
                  x_grids,
                  y_grids,
                  K=K,
                  scale=quiver_scale,
                  percentage=grid_z_b_percentage,
                  title="grid dynamics (mother)")
    plt.savefig(rslt_dir + "/dynamics/dynamics_b.jpg")
    plt.close()

    plot_quiver(XY_grids[:, 0:2],
                dXY[..., 0:2],
                'virgin',
                K=K,
                scale=quiver_scale,
                alpha=0.9,
                title="quiver (virgin)")
    plt.savefig(rslt_dir + "/dynamics/quiver_a.jpg")
    plt.close()

    plot_quiver(XY_grids[:, 2:4],
                dXY[..., 2:4],
                'mother',
                K=K,
                scale=quiver_scale,
                alpha=0.9,
                title="quiver (mother)")
    plt.savefig(rslt_dir + "/dynamics/quiver_b.jpg")
    plt.close()

    if not os.path.exists(rslt_dir + "/distributions"):
        os.makedirs(rslt_dir + "/distributions")
        print("Making distributions directory...")

    data_angles_a, data_angles_b = get_all_angles(data, x_grids, y_grids)
    sample_angles_a, sample_angles_b = get_all_angles(sample_x, x_grids,
                                                      y_grids)
    sample_x_center_angles_a, sample_x_center_angles_b = get_all_angles(
        sample_x_center, x_grids, y_grids)

    plot_list_of_angles(
        [data_angles_a, sample_angles_a, sample_x_center_angles_a],
        ['data', 'sample', 'sample_c'], "direction distribution (virgin)", n_x,
        n_y)
    plt.savefig(rslt_dir + "/distributions/angles_a.jpg")
    plt.close()
    plot_list_of_angles(
        [data_angles_b, sample_angles_b, sample_x_center_angles_b],
        ['data', 'sample', 'sample_c'], "direction distribution (mother)", n_x,
        n_y)
    plt.savefig(rslt_dir + "/distributions/angles_b.jpg")
    plt.close()

    data_speed_a, data_speed_b = get_speed(data, x_grids, y_grids)
    sample_speed_a, sample_speed_b = get_speed(sample_x, x_grids, y_grids)
    sample_x_center_speed_a, sample_x_center_speed_b = get_speed(
        sample_x_center, x_grids, y_grids)

    plot_list_of_speed([data_speed_a, sample_speed_a, sample_x_center_speed_a],
                       ['data', 'sample', 'sample_c'],
                       "speed distribution (virgin)", n_x, n_y)
    plt.savefig(rslt_dir + "/distributions/speed_a.jpg")
    plt.close()
    plot_list_of_speed([data_speed_b, sample_speed_b, sample_x_center_speed_b],
                       ['data', 'sample', 'sample_c'],
                       "speed distribution (mother)", n_x, n_y)
    plt.savefig(rslt_dir + "/distributions/speed_b.jpg")
    plt.close()

    try:
        if 100 < data.shape[0] <= 36000:
            plot_space_dist(data, x_grids, y_grids)
        elif data.shape[0] > 36000:
            plot_space_dist(data[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_data.jpg")
        plt.close()

        if 100 < sample_x.shape[0] <= 36000:
            plot_space_dist(sample_x, x_grids, y_grids)
        elif sample_x.shape[0] > 36000:
            plot_space_dist(sample_x[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x.jpg")
        plt.close()

        if 100 < sample_x_center.shape[0] <= 36000:
            plot_space_dist(sample_x_center, x_grids, y_grids)
        elif sample_x_center.shape[0] > 36000:
            plot_space_dist(sample_x_center[:36000], x_grids, y_grids)
        plt.savefig(rslt_dir + "/distributions/space_sample_x_center.jpg")
        plt.close()
    except:
        print("plot_space_dist unsuccessful")