def rslt_saving(rslt_dir, model, data, mouse, sample_T, train_model, losses,
                quiver_scale, x_grids, y_grids):
    """Evaluate a fitted model and write numbers and figures to ``rslt_dir``.

    The routine infers the most likely hidden states for ``data``, computes
    0-step and 5-step predictions on (at most) the last 5000 observations,
    draws two samples from the model (one unconditioned, one started from a
    fixed "center" prefix), evaluates the observation transformation on a
    30 x 30 grid of positions, and then saves:

    * ``summary.json`` -- model parameters, prediction errors, log-likelihood
      and average speeds (serialised with ``NumpyEncoder``);
    * ``numbers`` -- a joblib dump of states, predictions and samples (plus
      ``losses`` when ``train_model`` is truthy);
    * figures under ``samples/``, ``dynamics/`` and ``distributions/``.

    Args:
        rslt_dir: Output directory. It must already exist (it is not created
            here); the three sub-directories are created when missing.
        model: Fitted model. The code reads ``K``, ``init_dist``,
            ``observation.transformation``, ``observation.log_sigmas`` and
            ``transition.stationary_transition_matrix`` and calls
            ``most_likely_states``, ``sample`` and ``log_likelihood``.
        data: 2-D tensor with 2 or 4 columns (``.numpy()`` is called on it).
        mouse: Label used in titles and file names.
        sample_T: Forwarded to ``model.sample``; also embedded in the sample
            file names.
        train_model: When truthy, ``losses`` is stored and plotted.
        losses: Training losses; only used when ``train_model`` is truthy.
        quiver_scale: Forwarded to ``plot_quiver`` as ``scale``.
        x_grids: Grid values along x, forwarded to the plotting helpers;
            ``len(x_grids) - 1`` is passed on as ``n_x``.
        y_grids: Same as ``x_grids`` for the y axis.

    Raises:
        AssertionError: If ``data`` does not have 2 or 4 columns.
    """
    tran = model.observation.transformation

    _, D = data.shape
    assert D == 2 or D == 4, "D must be either 2 or 4."

    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1

    K = model.K

    def save_and_close(rel_path, **savefig_kwargs):
        # Save the current figure below rslt_dir and always release it, so
        # that no figure stays alive (or gets drawn over) afterwards.
        plt.savefig(rslt_dir + rel_path, **savefig_kwargs)
        plt.close()

    def ensure_subdir(name):
        # Create rslt_dir/<name> the first time it is needed.
        path = rslt_dir + "/" + name
        if not os.path.exists(path):
            os.makedirs(path)
            print("Making {} directory...".format(name))

    def mean_abs_step(trajectory):
        # Per-dimension mean absolute difference between consecutive rows.
        return np.average(np.abs(np.diff(trajectory, axis=0)), axis=0)

    #################### inference ###########################

    print("\ninferring most likely states...")
    z = model.most_likely_states(data)

    print("0 step prediction")
    # Predictions are restricted to the last 5000 observations. The states
    # are sliced the same way so that z_to_predict[t] still belongs to
    # data_to_predict[t] (assumes z has one entry per row of data).
    if data.shape[0] <= 5000:
        data_to_predict, z_to_predict = data, z
    else:
        data_to_predict, z_to_predict = data[-5000:], z[-5000:]
    x_predict = k_step_prediction(model, z_to_predict, data_to_predict)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()),
                            axis=0)

    print("5 step prediction")
    x_predict_5 = k_step_prediction(model, z_to_predict, data_to_predict, k=5)
    x_predict_5_err = np.mean(
        np.abs(x_predict_5 - data_to_predict[5:].numpy()), axis=0)

    ################### samples #########################

    sample_z, sample_x = model.sample(sample_T)

    center_z = torch.tensor([0], dtype=torch.int)
    if D == 4:
        center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64)
    else:
        center_x = torch.tensor([[150, 190]], dtype=torch.float64)
    sample_z_center, sample_x_center = model.sample(
        sample_T, prefix=(center_z, center_x))

    ################## dynamics #####################

    # quiver
    XX, YY = np.meshgrid(np.linspace(20, 310, 30), np.linspace(0, 380, 30))
    XY_grids = np.column_stack(
        (np.ravel(XX), np.ravel(YY)))  # shape (900,2) grid values
    if D == 4:
        # Same grid for both animals: columns 0:2 and 2:4.
        XY_grids = np.concatenate((XY_grids, XY_grids), axis=1)

    XY_next = tran.transform(torch.tensor(XY_grids, dtype=torch.float64))
    # XY_grids[:, None] inserts an axis so the grid broadcasts against the
    # extra (middle) axis of the transformed points.
    dXY = XY_next.detach().numpy() - XY_grids[:, None]

    #################### saving ##############################

    print("begin saving...")

    # save summary
    avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = mean_abs_step(sample_x)
    avg_sample_center_speed = mean_abs_step(sample_x_center)
    avg_data_speed = mean_abs_step(data.numpy())

    # detach() is valid whether or not the tensor requires grad, so a single
    # code path is enough.
    transition_matrix = (
        model.transition.stationary_transition_matrix.detach().numpy())

    cluster_centers = get_np(tran.mus_loc)

    summary_dict = {
        "init_dist": model.init_dist.detach().numpy(),
        "transition_matrix": transition_matrix,
        "x_predict_err": x_predict_err,
        "x_predict_5_err": x_predict_5_err,
        "mus": cluster_centers,
        "variance": torch.exp(model.observation.log_sigmas).detach().numpy(),
        "log_likes": model.log_likelihood(data).detach().numpy(),
        "avg_transform_speed": avg_transform_speed,
        "avg_data_speed": avg_data_speed,
        "avg_sample_speed": avg_sample_speed,
        "avg_sample_center_speed": avg_sample_center_speed,
    }
    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = {
        "z": z,
        "x_predict": x_predict,
        "x_predict_5": x_predict_5,
        "sample_z": sample_z,
        "sample_x": sample_x,
        "sample_z_center": sample_z_center,
        "sample_x_center": sample_x_center,
    }

    if train_model:
        saving_dict['losses'] = losses
        plt.figure()
        plt.plot(losses)
        save_and_close("/losses.jpg")
    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    plot_z(z, K, title="most likely z for the ground truth")
    save_and_close("/z.jpg")

    ensure_subdir("samples")

    plot_z(sample_z, K, title="sample")
    save_and_close("/samples/sample_z_{}.jpg".format(sample_T))

    plot_z(sample_z_center, K, title="sample (starting from center)")
    save_and_close("/samples/sample_z_center_{}.jpg".format(sample_T))

    trajectory_plots = (
        (data, "ground truth_{}".format(mouse), "/samples/ground_truth.jpg"),
        (sample_x, "sample_{}".format(mouse),
         "/samples/sample_x_{}.jpg".format(sample_T)),
        (sample_x_center, "sample (starting from center)_{}".format(mouse),
         "/samples/sample_x_center_{}.jpg".format(sample_T)),
    )
    for trajectory, title, rel_path in trajectory_plots:
        plt.figure(figsize=(4, 4))
        plot_mouse(trajectory,
                   title=title,
                   xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
                   ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
        plt.legend()
        save_and_close(rel_path)

    realdata_quivers = (
        (data, z, "ground truth", "/samples/quiver_ground_truth.jpg"),
        (sample_x, sample_z, "sample",
         "/samples/quiver_sample_x_{}.jpg".format(sample_T)),
        (sample_x_center, sample_z_center, "sample (starting from center)",
         "/samples/quiver_sample_x_center_{}.jpg".format(sample_T)),
    )
    for trajectory, states, title, rel_path in realdata_quivers:
        plot_realdata_quiver(trajectory,
                             states,
                             K,
                             x_grids,
                             y_grids,
                             title=title,
                             cluster_centers=cluster_centers)
        # The ground-truth figure used to be left open here.
        save_and_close(rel_path, dpi=200)

    # plot mus
    plot_cluster_centers(cluster_centers, x_grids, y_grids)
    save_and_close("/samples/cluster_centers.jpg", dpi=200)

    ensure_subdir("dynamics")

    if D == 2:
        quiver_plots = ((XY_grids, dXY, mouse, mouse),)
    else:
        quiver_plots = (
            (XY_grids[:, 0:2], dXY[..., 0:2], 'virgin', 'a'),
            (XY_grids[:, 2:4], dXY[..., 2:4], 'mother', 'b'),
        )
    for grid_points, grid_steps, who, suffix in quiver_plots:
        plot_quiver(grid_points,
                    grid_steps,
                    who,
                    K=K,
                    scale=quiver_scale,
                    alpha=0.9,
                    title="quiver ({})".format(who),
                    x_grids=x_grids,
                    y_grids=y_grids,
                    grid_alpha=0.2)
        save_and_close("/dynamics/quiver_{}.jpg".format(suffix), dpi=200)

    ensure_subdir("distributions")

    trajectories = (data, sample_x, sample_x_center)

    def plot_distributions(get_stat, plot_stat, kind, file_prefix):
        # Compute one statistic for data / sample / centered sample and plot
        # the three side by side. With 4 columns the helpers return one
        # entry per animal, which are plotted separately.
        stats = [get_stat(traj, x_grids, y_grids) for traj in trajectories]
        if D == 4:
            for idx, (who, suffix) in enumerate((("virgin", "a"),
                                                 ("mother", "b"))):
                plot_stat([stat[idx] for stat in stats],
                          ['data', 'sample', 'sample_c'],
                          "{} distribution ({})".format(kind, who), n_x, n_y)
                save_and_close("/distributions/{}_{}.jpg".format(
                    file_prefix, suffix))
        else:
            # Label the figure with the actual animal instead of a
            # hard-coded "virgin", matching the file name.
            plot_stat(stats, ['data', 'sample', 'sample_c'],
                      "{} distribution ({})".format(kind, mouse), n_x, n_y)
            save_and_close("/distributions/{}_{}.jpg".format(
                file_prefix, mouse))

    plot_distributions(get_all_angles, plot_list_of_angles, "direction",
                       "angles")
    plot_distributions(get_speed, plot_list_of_speed, "speed", "speed")

    # Best effort: a failure of the spatial-occupancy plots must not abort
    # the run, but only ordinary errors are swallowed and the cause is shown.
    try:
        space_plots = (
            (data, "/distributions/space_data.jpg"),
            (sample_x, "/distributions/space_sample_x.jpg"),
            (sample_x_center, "/distributions/space_sample_x_center.jpg"),
        )
        for trajectory, rel_path in space_plots:
            # Trajectories with 100 rows or fewer are not plotted; longer
            # ones are capped at their first 36000 rows.
            if trajectory.shape[0] > 100:
                plot_space_dist(trajectory[:36000], x_grids, y_grids)
            save_and_close(rel_path)
    except Exception as err:
        print("plot_space_dist unsuccessful: {}".format(err))
def rslt_saving(rslt_dir, model, data, animal, memory_kwargs, list_of_k_steps, sample_T,
                quiver_scale, x_grids=None, y_grids=None,
                valid_data=None,
                transition_memory_kwargs=None,
                valid_data_transition_memory_kwargs=None,
                valid_data_memory_kwargs=None, device=torch.device('cpu')):
    """Evaluate a fitted GP-observation model and save numbers and figures.

    The routine infers the most likely hidden states for the training (and,
    when given, validation) data, computes 0-step and k-step predictions,
    draws two samples from the model (one unconditioned, one started from a
    fixed "center" prefix), evaluates the observation mean on a 30 x 30 grid
    of positions, and then saves:

    * ``summary.json`` -- prediction errors, model parameters,
      log-likelihoods and average speeds (serialised with ``NumpyEncoder``);
    * ``numbers`` -- a joblib dump of states, predictions and samples;
    * figures under ``samples/``, ``dynamics/`` and ``distributions/``.

    Args:
        rslt_dir: Output directory. It must already exist (it is not created
            here); the three sub-directories are created when missing.
        model: Fitted model. ``model.observation`` must be a
            ``GPObservation`` when ``animal == 'both'`` and a
            ``GPObservationSingle`` otherwise; ``model.transition`` must be a
            ``StationaryTransition`` or a ``GridTransition``.
        data: Training data (``.shape`` and slicing are used on it).
        animal: ``'both'`` for two animals (4 columns are assumed), anything
            else is used as the single animal's label in titles/file names.
        memory_kwargs: Extra keyword arguments for the model calls on
            ``data``; a falsy value means "none".
        list_of_k_steps: Prediction horizons ``k`` to evaluate in addition to
            the 0-step prediction.
        sample_T: Forwarded to ``model.sample``; also embedded in the sample
            file names.
        quiver_scale: Forwarded to ``plot_quiver`` as ``scale``.
        x_grids: Grid values along x; taken from the observation when either
            grid is ``None``.
        y_grids: Same as ``x_grids`` for the y axis.
        valid_data: Optional validation data. When ``None`` or empty, all
            validation outputs are skipped or stored as ``None``.
        transition_memory_kwargs: Passed as ``transition_mkwargs`` when
            inferring the states of ``data``.
        valid_data_transition_memory_kwargs: Same, for ``valid_data``.
        valid_data_memory_kwargs: Extra keyword arguments for the model calls
            on ``valid_data``.
        device: Torch device for the center prefix and the angle/speed
            helpers.

    Raises:
        AssertionError: If the observation type does not match ``animal``.
        ValueError: If the transition type is not supported.
    """
    transition_memory_kwargs = transition_memory_kwargs or {}
    valid_data_transition_memory_kwargs = valid_data_transition_memory_kwargs or {}
    valid_data_memory_kwargs = valid_data_memory_kwargs or {}
    memory_kwargs = memory_kwargs or {}

    obs = model.observation
    if animal == 'both':
        assert isinstance(obs, GPObservation), type(obs)
    else:
        assert isinstance(obs, GPObservationSingle), type(obs)
    if x_grids is None or y_grids is None:
        x_grids = obs.x_grids
        y_grids = obs.y_grids
    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1

    K = model.K

    # valid_data defaults to None, so None must be handled as "no validation
    # data" instead of failing on the first use.
    valid_given = valid_data is not None
    has_valid_data = valid_given and len(valid_data) > 0

    def save_and_close(rel_path, **savefig_kwargs):
        # Save the current figure below rslt_dir and always release it, so
        # that no figure stays alive (or gets drawn over) afterwards.
        plt.savefig(rslt_dir + rel_path, **savefig_kwargs)
        plt.close()

    def ensure_subdir(name):
        # Create rslt_dir/<name> the first time it is needed.
        path = rslt_dir + "/" + name
        if not os.path.exists(path):
            os.makedirs(path)
            print("Making {} directory...".format(name))

    def mean_abs_step(trajectory):
        # Per-dimension mean absolute difference between consecutive rows.
        return np.average(np.abs(np.diff(trajectory, axis=0)), axis=0)

    def mean_abs_error(prediction, target):
        return np.mean(np.abs(prediction - get_np(target)), axis=0)

    #################### inference ###########################

    print("\ninferring most likely states...")
    z = model.most_likely_states(data, transition_mkwargs=transition_memory_kwargs, **memory_kwargs)
    z_valid = None
    if valid_given:
        z_valid = model.most_likely_states(valid_data,
                                           transition_mkwargs=valid_data_transition_memory_kwargs,
                                           **valid_data_memory_kwargs)

    print("0 step prediction")
    # TODO: fix kwargs not matching error
    data_to_predict = data

    x_predict = k_step_prediction_for_gpmodel(model, z, data_to_predict, **memory_kwargs)
    x_predict_valid = None
    if valid_given:
        x_predict_valid = k_step_prediction_for_gpmodel(model, z_valid, valid_data,
                                                        **valid_data_memory_kwargs)
    x_predict_err = mean_abs_error(x_predict, data_to_predict)
    x_predict_valid_err = mean_abs_error(x_predict_valid, valid_data) if has_valid_data else None

    dict_of_x_predict_k = dict(x_predict_0=x_predict, x_predict_v_0=x_predict_valid)
    dict_of_x_predict_k_err = dict(x_predict_0_err=x_predict_err, x_predict_v_0_err=x_predict_valid_err)

    for k_step in list_of_k_steps:
        print("{} step prediction".format(k_step))
        x_predict_k = k_step_prediction(model, z, data_to_predict, k=k_step)
        x_predict_k_err = mean_abs_error(x_predict_k, data_to_predict[k_step:])
        if has_valid_data:
            # The validation prediction must come from the validation states
            # and data; it used to be a copy of the training prediction.
            x_predict_valid_k = k_step_prediction(model, z_valid, valid_data, k=k_step)
            x_predict_valid_k_err = mean_abs_error(x_predict_valid_k, valid_data[k_step:])
        else:
            x_predict_valid_k = None
            x_predict_valid_k_err = None
        dict_of_x_predict_k["x_predict_{}".format(k_step)] = x_predict_k
        dict_of_x_predict_k["x_predict_v_{}".format(k_step)] = x_predict_valid_k
        dict_of_x_predict_k_err["x_predict_{}_err".format(k_step)] = x_predict_k_err
        dict_of_x_predict_k_err["x_predict_v_{}_err".format(k_step)] = x_predict_valid_k_err

    ################### samples #########################
    print("sampling")
    center_z = torch.tensor([0], dtype=torch.int, device=device)
    if animal == 'both':
        center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64, device=device)
    else:
        center_x = torch.tensor([[150, 190]], dtype=torch.float64, device=device)

    sample_z, sample_x = model.sample(sample_T)
    sample_z_center, sample_x_center = model.sample(sample_T, prefix=(center_z, center_x))

    ################## dynamics #####################
    print("dynamics")
    # quiver
    XX, YY = np.meshgrid(np.linspace(20, 310, 30),
                         np.linspace(0, 380, 30))
    XY_grids = np.column_stack((np.ravel(XX), np.ravel(YY)))  # shape (900,2) grid values
    if animal == 'both':
        # Same grid for both animals: columns 0:2 and 2:4.
        XY_grids = np.concatenate((XY_grids, XY_grids), axis=1)  # (900, 4)

    # TODO, fix not based on z
    if animal == 'both':
        XY_next_a, _ = obs.get_mu_and_cov_for_single_animal(XY_grids[:, 0:2], 0, mu_only=True)
        XY_next_b, _ = obs.get_mu_and_cov_for_single_animal(XY_grids[:, 2:4], 1, mu_only=True)
        XY_next = torch.cat((XY_next_a, XY_next_b), dim=-1)
    else:
        XY_next, _ = obs.get_mu_and_cov_for_single_animal(XY_grids, mu_only=True)
    # XY_grids[:, None] inserts an axis so the grid broadcasts against the
    # extra (middle) axis of the predicted means.
    dXY = get_np(XY_next) - XY_grids[:, None]

    #################### saving ##############################

    print("begin saving...")

    # save summary
    avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = mean_abs_step(sample_x)
    avg_sample_center_speed = mean_abs_step(sample_x_center)
    avg_data_speed = mean_abs_step(get_np(data))

    if isinstance(model.transition, StationaryTransition):
        transition_matrix = model.transition.stationary_transition_matrix
    elif isinstance(model.transition, GridTransition):
        transition_matrix = model.transition.grid_transition_matrix
    else:
        raise ValueError("unsupported transition matrix type: {}".format(type(model.transition)))

    transition_matrix = get_np(transition_matrix)

    # Prediction errors come first, followed by the model summary.
    summary_dict = dict(dict_of_x_predict_k_err)
    summary_dict.update({
        "init_dist": get_np(model.init_dist),
        "transition_matrix": transition_matrix,
        "variance": get_np(torch.exp(model.observation.log_sigmas)),
        "log_likes": get_np(model.log_likelihood(data, **memory_kwargs)),
        "avg_data_speed": avg_data_speed,
        "avg_sample_speed": avg_sample_speed,
        "avg_sample_center_speed": avg_sample_center_speed,
    })
    if has_valid_data:
        summary_dict["valid_log_likes"] = get_np(model.log_likelihood(valid_data, **valid_data_memory_kwargs))

    summary_dict["rs"] = get_np(model.observation.rs)
    summary_dict["avg_transform_speed"] = avg_transform_speed
    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = dict(dict_of_x_predict_k)
    saving_dict.update({
        "z": z, "z_valid": z_valid, "sample_z": sample_z, "sample_x": sample_x,
        "sample_z_center": sample_z_center, "sample_x_center": sample_x_center,
    })

    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    if model.D == 2 and isinstance(model.transition, GridTransition):
        plot_grid_transition(n_x, n_y, model.transition.grid_transition_matrix)
        save_and_close("/grid_transition.jpg")

    ensure_subdir("samples")

    # State plots are only drawn when there is more than one state.
    if K > 1:
        plot_z(z, K, title="most likely z for the ground truth")
        save_and_close("/z.jpg")

        if has_valid_data:
            plot_z(z_valid, K, title="most likely z for valid data")
            save_and_close("/z_valid.jpg")

        plot_z(sample_z, K, title="sample")
        save_and_close("/samples/sample_z_{}.jpg".format(sample_T))

        plot_z(sample_z_center, K, title="sample (starting from center)")
        save_and_close("/samples/sample_z_center_{}.jpg".format(sample_T))

    trajectory_plots = [(data, "ground truth (training)", "/samples/ground_truth.jpg")]
    if has_valid_data:
        trajectory_plots.append((valid_data, "ground truth (valid)", "/samples/ground_truth_valid.jpg"))
    trajectory_plots.append((sample_x, "sample", "/samples/sample_x_{}.jpg".format(sample_T)))
    trajectory_plots.append((sample_x_center, "sample (starting from center)",
                             "/samples/sample_x_center_{}.jpg".format(sample_T)))
    for trajectory, title, rel_path in trajectory_plots:
        plt.figure(figsize=(4, 4))
        # The upper x-limit is the arena's x extent (ARENA_XMAX); the y
        # extent was used here by mistake.
        plot_mouse(trajectory, title=title, xlim=[ARENA_XMIN - 20, ARENA_XMAX + 20],
                   ylim=[ARENA_YMIN - 20, ARENA_YMAX + 20])
        plt.legend()
        save_and_close(rel_path)

    if K > 1:
        realdata_quivers = [(data, z, "ground truth (training)", "/samples/quiver_ground_truth.jpg")]
        if has_valid_data:
            realdata_quivers.append((valid_data, z_valid, "ground truth (valid)",
                                     "/samples/quiver_ground_truth_valid.jpg"))
        realdata_quivers.append((sample_x, sample_z, "sample",
                                 "/samples/quiver_sample_x_{}.jpg".format(sample_T)))
        realdata_quivers.append((sample_x_center, sample_z_center, "sample (starting from center)",
                                 "/samples/quiver_sample_x_center_{}.jpg".format(sample_T)))
        for trajectory, states, title, rel_path in realdata_quivers:
            plot_realdata_quiver(trajectory, states, K, x_grids, y_grids, title=title)
            save_and_close(rel_path, dpi=200)

    ensure_subdir("dynamics")

    if animal == 'both':
        quiver_plots = (
            (XY_grids[:, 0:2], dXY[..., 0:2], 'virgin', 'a'),
            (XY_grids[:, 2:4], dXY[..., 2:4], 'mother', 'b'),
        )
    else:
        quiver_plots = ((XY_grids, dXY, animal, animal),)
    for grid_points, grid_steps, who, suffix in quiver_plots:
        plot_quiver(grid_points, grid_steps, who, K=K, scale=quiver_scale, alpha=0.9,
                    title="quiver ({})".format(who), x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
        save_and_close("/dynamics/quiver_{}.jpg".format(suffix), dpi=200)

    ensure_subdir("distributions")

    # sanity checks (each figure is closed after saving so none is leaked)
    labelled_trajectories = (
        (data, z, "groundtruth", 'ground truth'),
        (sample_x, sample_z, "sample_x", 'sample_x'),
        (sample_x_center, sample_z_center, "sample_x_center", 'sample_x_center'),
    )
    for trajectory, states, tag, _ in labelled_trajectories:
        plot_data_condition_on_all_zs(trajectory, states, K, size=2, alpha=0.3)
        save_and_close("/distributions/spatial_occup_{}.jpg".format(tag), dpi=100)

    for trajectory, states, tag, title in labelled_trajectories:
        plot_2d_time_plot_condition_on_all_zs(trajectory, states, K, title=title)
        save_and_close("/distributions/4traces_{}.jpg".format(tag), dpi=100)

    trajectories = (data, sample_x, sample_x_center)

    def plot_distributions(get_stat, plot_stat, kind, file_prefix):
        # Compute one statistic for data / sample / centered sample and plot
        # the three side by side. For 'both', every result holds one entry
        # per animal and ALL three must be indexed (the centered sample's
        # angles used to be passed unindexed).
        stats = [get_stat(traj, x_grids, y_grids, device=device) for traj in trajectories]
        if animal == 'both':
            for idx, (who, suffix) in enumerate((("virgin", "a"), ("mother", "b"))):
                plot_stat([stat[idx] for stat in stats], ['data', 'sample', 'sample_c'],
                          "{} distribution ({})".format(kind, who), n_x, n_y)
                save_and_close("/distributions/{}_{}.jpg".format(file_prefix, suffix))
        else:
            plot_stat(stats, ['data', 'sample', 'sample_c'],
                      "{} distribution ({})".format(kind, animal), n_x, n_y)
            save_and_close("/distributions/{}_{}.jpg".format(file_prefix, animal))

    plot_distributions(get_all_angles, plot_list_of_angles, "direction", "angles")
    plot_distributions(get_speed, plot_list_of_speed, "speed", "speed")

    # Best effort: a failure of the spatial-occupancy plots must not abort
    # the run, but only ordinary errors are swallowed and the cause is shown.
    try:
        space_plots = (
            (data, "/distributions/space_data.jpg"),
            (sample_x, "/distributions/space_sample_x.jpg"),
            (sample_x_center, "/distributions/space_sample_x_center.jpg"),
        )
        for trajectory, rel_path in space_plots:
            # Trajectories with 100 rows or fewer are not plotted; longer
            # ones are capped at their first 36000 rows.
            if trajectory.shape[0] > 100:
                plot_space_dist(trajectory[:36000], x_grids, y_grids)
            save_and_close(rel_path)
    except Exception as err:
        print("plot_space_dist unsuccessful: {}".format(err))
def rslt_saving(rslt_dir, model, data, animal, memory_kwargs, list_of_k_steps, sample_T,
                quiver_scale, x_grids=None, y_grids=None, dynamics_T=None,
                valid_data=None, valid_data_memory_kwargs=None, device=torch.device('cpu')):
    """Run inference, prediction and sampling for ``model`` and save everything under ``rslt_dir``.

    Outputs written:
      * ``summary.json``  -- model parameters, log-likelihoods, prediction errors, average speeds.
      * ``numbers``       -- joblib dump of inferred states, predictions and samples.
      * figures in ``rslt_dir`` and its ``samples``, ``dynamics`` and ``distributions``
        sub-directories (sub-directories are created when missing).

    :param rslt_dir: output directory path (string). ``summary.json`` is opened inside it
        before any sub-directory is created.
    :param model: model exposing ``K``, ``D``, ``init_dist``, ``transition``,
        ``observation.transformation``, ``observation.log_sigmas``, ``most_likely_states``,
        ``sample`` and ``log_likelihood``.
    :param data: training observations.
        NOTE(review): presumably (T, 2) for one animal and (T, 4) for ``animal == "both"`` -- confirm.
    :param animal: ``"both"`` switches on the two-animal (virgin / mother) handling; any other
        value is used as a label in titles and file names.
    :param memory_kwargs: extra keyword arguments forwarded to model calls on ``data``;
        a falsy value is replaced by ``{}``.
    :param list_of_k_steps: iterable of ``k`` values for k-step-ahead prediction.
    :param sample_T: length passed to ``model.sample``.
    :param quiver_scale: ``scale`` forwarded to ``plot_quiver``.
    :param x_grids: grid boundaries; if either ``x_grids`` or ``y_grids`` is None both are
        taken from the transformation.
    :param y_grids: see ``x_grids``.
    :param dynamics_T: length of the fixed-state samples; must not be None for
        ``LSTMBasedTransformation``.
    :param valid_data: optional validation observations. ``None`` is treated like empty
        validation data.
    :param valid_data_memory_kwargs: like ``memory_kwargs`` but for ``valid_data``.
    :param device: torch device used for tensors created in this function.
    :raises ValueError: if the transformation or the transition type is not supported.
    """

    valid_data_memory_kwargs = valid_data_memory_kwargs if valid_data_memory_kwargs else {}
    memory_kwargs = memory_kwargs if memory_kwargs else {}

    if valid_data is None:
        # The rest of the function tests ``len(valid_data)``, which raises on None.
        # An empty slice of the training data keeps type/device and takes the same
        # code path as an explicitly empty validation set.
        valid_data = data[:0]
    has_valid_data = len(valid_data) > 0

    tran = model.observation.transformation
    if x_grids is None or y_grids is None:
        x_grids = tran.x_grids
        y_grids = tran.y_grids
    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1

    K = model.K

    # Transformations for which a quiver plot of the one-step dynamics is produced.
    grid_tran_types = (LinearGridTransformation, SingleLinearGridTransformation,
                       GPGridTransformation, WeightedGridTransformation)
    is_grid_tran = isinstance(tran, grid_tran_types)
    is_lstm_based_tran = isinstance(tran, LSTMBasedTransformation)

    #################### inference ###########################

    print("\ninferring most likely states...")
    z = model.most_likely_states(data, **memory_kwargs)
    z_valid = model.most_likely_states(valid_data, **valid_data_memory_kwargs)

    print("0 step prediction")
    # TODO: add valid data for other model
    # Only the last 10000 training points are used for prediction.
    if data.shape[0] <= 10000:
        data_to_predict = data
    else:
        data_to_predict = data[-10000:]
    if isinstance(tran, (LinearGridTransformation, SingleLinearGridTransformation)):
        x_predict = k_step_prediction_for_lineargrid_model(model, z, data_to_predict, **memory_kwargs)
        x_predict_valid = k_step_prediction_for_lineargrid_model(model, z_valid, valid_data,
                                                                 **valid_data_memory_kwargs)
    elif isinstance(tran, GPGridTransformation):
        x_predict = k_step_prediction_for_gpgrid_model(model, z, data_to_predict, **memory_kwargs)
        x_predict_valid = k_step_prediction_for_gpgrid_model(model, z_valid, valid_data,
                                                             **valid_data_memory_kwargs)
    elif isinstance(tran, WeightedGridTransformation):
        x_predict = k_step_prediction_for_weightedgrid_model(model, z, data_to_predict, **memory_kwargs)
        x_predict_valid = k_step_prediction_for_weightedgrid_model(model, z_valid, valid_data,
                                                                   **valid_data_memory_kwargs)
    elif isinstance(tran, (LSTMTransformation, UniLSTMTransformation)):
        x_predict = k_step_prediction_for_lstm_model(model, z, data_to_predict,
                                                     feature_vecs=memory_kwargs["feature_vecs"])
        x_predict_valid = k_step_prediction_for_lstm_model(model, z_valid, valid_data,
                                                           feature_vecs=valid_data_memory_kwargs["feature_vecs"])
    elif is_lstm_based_tran:
        x_predict = k_step_prediction_for_lstm_based_model(model, z, data_to_predict, k=0,
                                                           feature_vecs=memory_kwargs["feature_vecs"])
        x_predict_valid = k_step_prediction_for_lstm_based_model(model, z_valid, valid_data, k=0,
                                                                 feature_vecs=valid_data_memory_kwargs["feature_vecs"])
    else:
        raise ValueError("Unsupported transformation!")
    x_predict_err = np.mean(np.abs(x_predict - get_np(data_to_predict)), axis=0)
    if has_valid_data:
        x_predict_valid_err = np.mean(np.abs(x_predict_valid - get_np(valid_data)), axis=0)
    else:
        x_predict_valid_err = None

    dict_of_x_predict_k = dict(x_predict_0=x_predict, x_predict_v_0=x_predict_valid)
    dict_of_x_predict_k_err = dict(x_predict_0_err=x_predict_err, x_predict_v_0_err=x_predict_valid_err)

    for k_step in list_of_k_steps:
        print("{} step prediction".format(k_step))
        if is_lstm_based_tran:
            # TODO: take care of empty valid data
            x_predict_k = k_step_prediction_for_lstm_based_model(model, z, data_to_predict, k=k_step)
            # The validation prediction must come from the validation states/data (it is
            # compared against valid_data below), using the same helper as the training one.
            x_predict_valid_k = k_step_prediction_for_lstm_based_model(model, z_valid, valid_data, k=k_step)
        else:
            x_predict_k = k_step_prediction(model, z, data_to_predict, k=k_step)
            x_predict_valid_k = k_step_prediction(model, z_valid, valid_data, k=k_step)
        # A k-step prediction is compared against the data shifted by k steps.
        x_predict_k_err = np.mean(np.abs(x_predict_k - get_np(data_to_predict[k_step:])), axis=0)
        if has_valid_data:
            x_predict_valid_k_err = np.mean(np.abs(x_predict_valid_k - get_np(valid_data[k_step:])), axis=0)
        else:
            x_predict_valid_k_err = None
        dict_of_x_predict_k["x_predict_{}".format(k_step)] = x_predict_k
        dict_of_x_predict_k["x_predict_v_{}".format(k_step)] = x_predict_valid_k
        dict_of_x_predict_k_err["x_predict_{}_err".format(k_step)] = x_predict_k_err
        dict_of_x_predict_k_err["x_predict_v_{}_err".format(k_step)] = x_predict_valid_k_err

    ################### samples #########################
    print("sampling")
    # Prefix (state 0, fixed position) used for the "starting from center" sample.
    center_z = torch.tensor([0], dtype=torch.int, device=device)
    if animal == "both":
        center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64, device=device)
    else:
        center_x = torch.tensor([[150, 190]], dtype=torch.float64, device=device)

    if is_lstm_based_tran:
        # A fresh lstm_states dict is handed to every sampling call.
        lstm_states = {}
        sample_z, sample_x = model.sample(sample_T, lstm_states=lstm_states)

        lstm_states = {}
        sample_z_center, sample_x_center = model.sample(sample_T, prefix=(center_z, center_x),
                                                        lstm_states=lstm_states)
    else:
        sample_z, sample_x = model.sample(sample_T)

        sample_z_center, sample_x_center = model.sample(sample_T, prefix=(center_z, center_x))

    ################## dynamics #####################

    if is_grid_tran:
        # quiver: evaluate the transformation on a regular 30 x 30 grid of positions
        XX, YY = np.meshgrid(np.linspace(20, 310, 30),
                             np.linspace(0, 380, 30))
        XY = np.column_stack((np.ravel(XX), np.ravel(YY)))  # shape (900,2) grid values
        if animal == 'both':
            XY_grids = np.concatenate((XY, XY), axis=1)  # (900, 4)
        else:
            XY_grids = XY

        XY_next = tran.transform(torch.tensor(XY_grids, dtype=torch.float64, device=device))
        dXY = get_np(XY_next) - XY_grids[:, None]

    # TODO: maybe use sample condition on z (transformation) to show the dynamics
    samples_on_fixed_zs = []
    if is_lstm_based_tran:
        assert dynamics_T is not None
        for k in range(K):
            lstm_states = {}
            fixed_z = torch.ones(dynamics_T, dtype=torch.int) * k
            samples_on_fixed_z = model.sample_condition_on_zs(zs=fixed_z, transformation=True, return_np=True,
                                                              lstm_states=lstm_states)
            samples_on_fixed_zs.append(samples_on_fixed_z)

    #################### saving ##############################

    print("begin saving...")

    # save summary
    if is_grid_tran:
        avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = np.average(np.abs(np.diff(sample_x, axis=0)), axis=0)
    avg_sample_center_speed = np.average(np.abs(np.diff(sample_x_center, axis=0)), axis=0)
    avg_data_speed = np.average(np.abs(np.diff(get_np(data), axis=0)), axis=0)

    if isinstance(model.transition, StationaryTransition):
        transition_matrix = model.transition.stationary_transition_matrix
    elif isinstance(model.transition, GridTransition):
        transition_matrix = model.transition.grid_transition_matrix
    else:
        raise ValueError("unsupported transition matrix type: {}".format(type(model.transition)))

    transition_matrix = get_np(transition_matrix)

    summary_dict = {"init_dist": get_np(model.init_dist),
                    "transition_matrix": transition_matrix,
                    "variance": get_np(torch.exp(model.observation.log_sigmas)),
                    "log_likes": get_np(model.log_likelihood(data, **memory_kwargs)),
                    "avg_data_speed": avg_data_speed,
                    "avg_sample_speed": avg_sample_speed, "avg_sample_center_speed": avg_sample_center_speed}
    summary_dict = {**dict_of_x_predict_k_err, **summary_dict}
    if has_valid_data:
        summary_dict["valid_log_likes"] = get_np(model.log_likelihood(valid_data, **valid_data_memory_kwargs))
    if isinstance(tran, GPGridTransformation):
        summary_dict["real_rs"] = get_np(tran.rs_factor * torch.sigmoid(tran.rs))
    if isinstance(tran, WeightedGridTransformation):
        summary_dict["beta"] = get_np(tran.beta)
    if is_grid_tran:
        # Stored for every transformation it was computed for (including
        # SingleLinearGridTransformation).
        summary_dict["avg_transform_speed"] = avg_transform_speed
    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = {"z": z, "z_valid": z_valid, "sample_z": sample_z, "sample_x": sample_x,
                   "sample_z_center": sample_z_center, "sample_x_center": sample_x_center}
    saving_dict = {**dict_of_x_predict_k, **saving_dict}
    if is_lstm_based_tran:
        saving_dict["samples_on_fixed_zs"] = samples_on_fixed_zs

    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    if model.D == 2 and isinstance(model.transition, GridTransition):
        plot_grid_transition(n_x, n_y, model.transition.grid_transition_matrix)
        plt.savefig(rslt_dir + "/grid_transition.jpg")
        plt.close()

    plot_z(z, K, title="most likely z for the ground truth")
    plt.savefig(rslt_dir + "/z.jpg")
    plt.close()

    if has_valid_data:
        plot_z(z_valid, K, title="most likely z for valid data")
        plt.savefig(rslt_dir + "/z_valid.jpg")
        plt.close()

    if not os.path.exists(rslt_dir + "/samples"):
        os.makedirs(rslt_dir + "/samples")
        print("Making samples directory...")

    plot_z(sample_z, K, title="sample")
    plt.savefig(rslt_dir + "/samples/sample_z_{}.jpg".format(sample_T))
    plt.close()

    plot_z(sample_z_center, K, title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/sample_z_center_{}.jpg".format(sample_T))
    plt.close()

    # Axis limits: the arena extent padded by 20 on every side.
    arena_xlim = [ARENA_XMIN - 20, ARENA_XMAX + 20]
    arena_ylim = [ARENA_YMIN - 20, ARENA_YMAX + 20]

    plt.figure(figsize=(4, 4))
    plot_mouse(data, title="ground truth (training)", xlim=arena_xlim, ylim=arena_ylim)
    plt.legend()
    plt.savefig(rslt_dir + "/samples/ground_truth.jpg")
    plt.close()

    if has_valid_data:
        plt.figure(figsize=(4, 4))
        plot_mouse(valid_data, title="ground truth (valid)", xlim=arena_xlim, ylim=arena_ylim)
        plt.legend()
        plt.savefig(rslt_dir + "/samples/ground_truth_valid.jpg")
        plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x, title="sample", xlim=arena_xlim, ylim=arena_ylim)
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x_center, title="sample (starting from center)", xlim=arena_xlim, ylim=arena_ylim)
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_center_{}.jpg".format(sample_T))
    plt.close()

    plot_realdata_quiver(data, z, K, x_grids, y_grids, title="ground truth (training)")
    plt.savefig(rslt_dir + "/samples/quiver_ground_truth.jpg", dpi=200)
    plt.close()

    if has_valid_data:
        plot_realdata_quiver(valid_data, z_valid, K, x_grids, y_grids, title="ground truth (valid)")
        plt.savefig(rslt_dir + "/samples/quiver_ground_truth_valid.jpg", dpi=200)
        plt.close()

    plot_realdata_quiver(sample_x, sample_z, K, x_grids, y_grids, title="sample")
    plt.savefig(rslt_dir + "/samples/quiver_sample_x_{}.jpg".format(sample_T), dpi=200)
    plt.close()

    plot_realdata_quiver(sample_x_center, sample_z_center, K, x_grids, y_grids, title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/quiver_sample_x_center_{}.jpg".format(sample_T), dpi=200)
    plt.close()

    if is_grid_tran:
        if not os.path.exists(rslt_dir + "/dynamics"):
            os.makedirs(rslt_dir + "/dynamics")
            print("Making dynamics directory...")

        if animal == 'both':
            # Columns 0:2 belong to the virgin, columns 2:4 to the mother.
            plot_quiver(XY_grids[:, 0:2], dXY[..., 0:2], 'virgin', K=K, scale=quiver_scale, alpha=0.9,
                        title="quiver (virgin)", x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
            plt.savefig(rslt_dir + "/dynamics/quiver_a.jpg", dpi=200)
            plt.close()

            plot_quiver(XY_grids[:, 2:4], dXY[..., 2:4], 'mother', K=K, scale=quiver_scale, alpha=0.9,
                        title="quiver (mother)", x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
            plt.savefig(rslt_dir + "/dynamics/quiver_b.jpg", dpi=200)
            plt.close()
        else:
            plot_quiver(XY_grids, dXY, animal, K=K, scale=quiver_scale, alpha=0.9,
                        title="quiver ({})".format(animal), x_grids=x_grids, y_grids=y_grids, grid_alpha=0.2)
            plt.savefig(rslt_dir + "/dynamics/quiver_{}.jpg".format(animal), dpi=200)
            plt.close()

    elif is_lstm_based_tran:
        if not os.path.exists(rslt_dir + "/dynamics"):
            os.makedirs(rslt_dir + "/dynamics")
            print("Making dynamics directory...")

        for k in range(K):
            # Builtin ``int`` replaces the ``np.int`` alias, which no longer exists in
            # recent NumPy releases.
            plot_realdata_quiver(samples_on_fixed_zs[k], np.ones(dynamics_T, dtype=int) * k, K,
                                 x_grids=x_grids, y_grids=y_grids,
                                 title="sample conditioned on k={}".format(k))
            plt.savefig(rslt_dir + "/dynamics/samples_on_k{}.jpg".format(k), dpi=200)
            plt.close()

    if not os.path.exists(rslt_dir + "/distributions"):
        os.makedirs(rslt_dir + "/distributions")
        print("Making distributions directory...")

    # sanity checks (each figure is closed after saving, like every other figure here)
    plot_data_condition_on_all_zs(data, z, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_groundtruth.jpg", dpi=100)
    plt.close()
    plot_data_condition_on_all_zs(sample_x, sample_z, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_sample_x.jpg", dpi=100)
    plt.close()
    plot_data_condition_on_all_zs(sample_x_center, sample_z_center, K, size=2, alpha=0.3)
    plt.savefig(rslt_dir + "/distributions/spatial_occup_sample_x_center.jpg", dpi=100)
    plt.close()

    plot_2d_time_plot_condition_on_all_zs(data, z, K, title='ground truth')
    plt.savefig(rslt_dir + "/distributions/4traces_groundtruth.jpg", dpi=100)
    plt.close()
    plot_2d_time_plot_condition_on_all_zs(sample_x, sample_z, K, title='sample_x')
    plt.savefig(rslt_dir + "/distributions/4traces_sample_x.jpg", dpi=100)
    plt.close()
    plot_2d_time_plot_condition_on_all_zs(sample_x_center, sample_z_center, K, title='sample_x_center')
    plt.savefig(rslt_dir + "/distributions/4traces_sample_x_center.jpg", dpi=100)
    plt.close()

    data_angles = get_all_angles(data, x_grids, y_grids, device=device)
    sample_angles = get_all_angles(sample_x, x_grids, y_grids, device=device)
    sample_x_center_angles = get_all_angles(sample_x_center, x_grids, y_grids,
                                            device=device)

    if animal == 'both':
        # Index 0 is the virgin, index 1 the mother; all three sources are indexed alike.
        plot_list_of_angles([data_angles[0], sample_angles[0], sample_x_center_angles[0]],
                            ['data', 'sample', 'sample_c'], "direction distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_a.jpg")
        plt.close()
        plot_list_of_angles([data_angles[1], sample_angles[1], sample_x_center_angles[1]],
                            ['data', 'sample', 'sample_c'], "direction distribution (mother)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_b.jpg")
        plt.close()
    else:
        plot_list_of_angles([data_angles, sample_angles, sample_x_center_angles],
                            ['data', 'sample', 'sample_c'], "direction distribution ({})".format(animal), n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/angles_{}.jpg".format(animal))
        plt.close()

    data_speed = get_speed(data, x_grids, y_grids, device=device)
    sample_speed = get_speed(sample_x, x_grids, y_grids, device=device)
    sample_x_center_speed = get_speed(sample_x_center, x_grids, y_grids, device=device)

    if animal == 'both':
        plot_list_of_speed([data_speed[0], sample_speed[0], sample_x_center_speed[0]],
                           ['data', 'sample', 'sample_c'], "speed distribution (virgin)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_a.jpg")
        plt.close()
        plot_list_of_speed([data_speed[1], sample_speed[1], sample_x_center_speed[1]],
                           ['data', 'sample', 'sample_c'], "speed distribution (mother)", n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_b.jpg")
        plt.close()
    else:
        plot_list_of_speed([data_speed, sample_speed, sample_x_center_speed],
                           ['data', 'sample', 'sample_c'], "speed distribution ({})".format(animal), n_x, n_y)
        plt.savefig(rslt_dir + "/distributions/speed_{}.jpg".format(animal))
        plt.close()

    # Best effort: a failure of the spatial distribution plots must not abort the run.
    # ``Exception`` (rather than a bare except) lets KeyboardInterrupt/SystemExit through.
    try:
        for points, file_stem in ((data, "space_data"),
                                  (sample_x, "space_sample_x"),
                                  (sample_x_center, "space_sample_x_center")):
            # Plot only when there are more than 100 points, and at most the first 36000.
            if points.shape[0] > 100:
                plot_space_dist(points[:36000], x_grids, y_grids)
            plt.savefig(rslt_dir + "/distributions/{}.jpg".format(file_stem))
            plt.close()
    except Exception:
        print("plot_space_dist unsuccessful")
def rslt_saving(rslt_dir, model, Df, data, masks_a, masks_b, m_kwargs_a,
                m_kwargs_b, sample_T, train_model, losses, quiver_scale):
    """Run inference, prediction and sampling for a two-animal grid model and save the results.

    Outputs written under ``rslt_dir``: ``summary.json``, a joblib dump ``numbers``,
    ``z.jpg`` (plus ``losses.jpg`` when ``train_model`` is truthy) and figures in the
    ``samples``, ``dynamics`` and ``distributions`` sub-directories, which are created
    when missing.

    :param rslt_dir: output directory path (string).
    :param model: model exposing ``K``, ``init_dist``, ``transition``,
        ``observation.transformation``, ``observation.log_sigmas``, ``most_likely_states``,
        ``sample`` and ``log_likelihood``.
    :param Df: forwarded to ``plot_weights``.
    :param data: observations as a torch tensor (``.numpy()`` is called on it).
        NOTE(review): presumably (T, 4) = virgin (x, y) followed by mother (x, y) -- confirm.
    :param masks_a: masks for animal a (virgin); passed to ``most_likely_states`` and
        ``get_z_percentage_by_grid``.
    :param masks_b: same for animal b (mother).
    :param m_kwargs_a: forwarded as ``memory_kwargs_a``.
    :param m_kwargs_b: forwarded as ``memory_kwargs_b``.
    :param sample_T: length passed to ``model.sample``.
    :param train_model: when truthy, ``losses`` is stored in ``numbers`` and plotted.
    :param losses: training losses; only used when ``train_model`` is truthy.
    :param quiver_scale: ``scale`` forwarded to ``plot_dynamics`` and ``plot_quiver``.
    """

    tran = model.observation.transformation
    x_grids = tran.x_grids
    y_grids = tran.y_grids
    n_x = len(x_grids) - 1
    n_y = len(y_grids) - 1
    G = n_x * n_y  # number of grid cells
    f_corner_vec_func = tran.transformations_a[0].feature_vec_func
    K = model.K

    #################### inference ###########################

    print("\ninferring most likely states...")
    z = model.most_likely_states(data,
                                 masks=(masks_a, masks_b),
                                 memory_kwargs_a=m_kwargs_a,
                                 memory_kwargs_b=m_kwargs_b)

    print("0 step prediction")
    # Only the last 1000 points are used for prediction.
    if data.shape[0] <= 1000:
        data_to_predict = data
    else:
        data_to_predict = data[-1000:]
    x_predict = k_step_prediction_for_grid_model(model,
                                                 z,
                                                 data_to_predict,
                                                 memory_kwargs_a=m_kwargs_a,
                                                 memory_kwargs_b=m_kwargs_b)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()),
                            axis=0)

    print("5 step prediction")
    x_predict_5 = k_step_prediction(model, z, data_to_predict, k=5)
    # A 5-step prediction is compared against the data shifted by 5 steps.
    x_predict_5_err = np.mean(np.abs(x_predict_5 -
                                     data_to_predict[5:].numpy()),
                              axis=0)

    ################### samples #########################

    sample_z, sample_x = model.sample(sample_T)

    # Prefix (state 0, fixed position) used for the "starting from center" sample.
    center_z = torch.tensor([0], dtype=torch.int)
    center_x = torch.tensor([[150, 190, 200, 200]], dtype=torch.float64)
    sample_z_center, sample_x_center = model.sample(sample_T,
                                                    prefix=(center_z,
                                                            center_x))

    ################## dynamics #####################

    # weights
    weights_a = np.array(
        [t.weights.detach().numpy() for t in tran.transformations_a])
    weights_b = np.array(
        [t.weights.detach().numpy() for t in tran.transformations_b])

    # dynamics: evaluate the corner feature vectors at the centre of every grid cell
    grid_centers = np.array([[
        1 / 2 * (x_grids[i] + x_grids[i + 1]),
        1 / 2 * (y_grids[j] + y_grids[j + 1])
    ] for i in range(n_x) for j in range(n_y)])
    unit_corner_vecs = f_corner_vec_func(
        torch.tensor(grid_centers, dtype=torch.float64))
    unit_corner_vecs = unit_corner_vecs.numpy()
    # (G, 1, Df, d) * (G, K, Df, 1) --> (G, K, Df, d)
    weighted_corner_vecs_a = unit_corner_vecs[:, None] * weights_a[..., None]
    weighted_corner_vecs_b = unit_corner_vecs[:, None] * weights_b[..., None]

    grid_z_a_percentage = get_z_percentage_by_grid(masks_a, z, K, G)
    grid_z_b_percentage = get_z_percentage_by_grid(masks_b, z, K, G)

    # quiver: evaluate the transformation on a regular 30 x 30 grid of positions
    XX, YY = np.meshgrid(np.linspace(20, 310, 30), np.linspace(0, 380, 30))
    XY = np.column_stack(
        (np.ravel(XX), np.ravel(YY)))  # shape (900,2) grid values
    XY_grids = np.concatenate((XY, XY), axis=1)

    XY_next = tran.transform(torch.tensor(XY_grids, dtype=torch.float64))
    dXY = XY_next.detach().numpy() - XY_grids[:, None]

    #################### saving ##############################

    print("begin saving...")

    # save summary
    avg_transform_speed = np.average(np.abs(dXY), axis=0)
    avg_sample_speed = np.average(np.abs(np.diff(sample_x, axis=0)), axis=0)
    avg_sample_center_speed = np.average(np.abs(
        np.diff(sample_x_center, axis=0)),
                                         axis=0)
    avg_data_speed = np.average(np.abs(np.diff(data.numpy(), axis=0)), axis=0)

    # detach() is valid whether or not the tensor requires grad, so no branching is needed.
    transition_matrix = model.transition.stationary_transition_matrix.detach().numpy()
    summary_dict = {
        "init_dist": model.init_dist.detach().numpy(),
        "transition_matrix": transition_matrix,
        "x_predict_err": x_predict_err,
        "x_predict_5_err": x_predict_5_err,
        "variance": torch.exp(model.observation.log_sigmas).detach().numpy(),
        "log_likes": model.log_likelihood(data).detach().numpy(),
        "grid_z_a_percentage": grid_z_a_percentage,
        "grid_z_b_percentage": grid_z_b_percentage,
        "avg_transform_speed": avg_transform_speed,
        "avg_data_speed": avg_data_speed,
        "avg_sample_speed": avg_sample_speed,
        "avg_sample_center_speed": avg_sample_center_speed
    }
    with open(rslt_dir + "/summary.json", "w") as f:
        json.dump(summary_dict, f, indent=4, cls=NumpyEncoder)

    # save numbers
    saving_dict = {
        "z": z,
        "x_predict": x_predict,
        "x_predict_5": x_predict_5,
        "sample_z": sample_z,
        "sample_x": sample_x,
        "sample_z_center": sample_z_center,
        "sample_x_center": sample_x_center
    }

    if train_model:
        saving_dict['losses'] = losses
        plt.figure()
        plt.plot(losses)
        plt.savefig(rslt_dir + "/losses.jpg")
        plt.close()
    joblib.dump(saving_dict, rslt_dir + "/numbers")

    # save figures
    plot_z(z, K, title="most likely z for the ground truth")
    plt.savefig(rslt_dir + "/z.jpg")
    plt.close()

    if not os.path.exists(rslt_dir + "/samples"):
        os.makedirs(rslt_dir + "/samples")
        print("Making samples directory...")

    plot_z(sample_z, K, title="sample")
    plt.savefig(rslt_dir + "/samples/sample_z_{}.jpg".format(sample_T))
    plt.close()

    plot_z(sample_z_center, K, title="sample (starting from center)")
    plt.savefig(rslt_dir + "/samples/sample_z_center_{}.jpg".format(sample_T))
    plt.close()

    # Axis limits: the arena extent padded by 20 on every side.
    arena_xlim = [ARENA_XMIN - 20, ARENA_XMAX + 20]
    arena_ylim = [ARENA_YMIN - 20, ARENA_YMAX + 20]

    plt.figure(figsize=(4, 4))
    plot_mouse(data,
               title="ground truth",
               xlim=arena_xlim,
               ylim=arena_ylim)
    plt.legend()
    plt.savefig(rslt_dir + "/samples/ground_truth.jpg")
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x,
               title="sample",
               xlim=arena_xlim,
               ylim=arena_ylim)
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_{}.jpg".format(sample_T))
    plt.close()

    plt.figure(figsize=(4, 4))
    plot_mouse(sample_x_center,
               title="sample (starting from center)",
               xlim=arena_xlim,
               ylim=arena_ylim)
    plt.legend()
    plt.savefig(rslt_dir + "/samples/sample_x_center_{}.jpg".format(sample_T))
    plt.close()

    plot_realdata_quiver(data, z, K, x_grids, y_grids, title="ground truth")
    plt.savefig(rslt_dir + "/samples/ground_truth_quiver.jpg")
    # Close this figure like all the others so it does not stay open.
    plt.close()

    plot_realdata_quiver(sample_x,
                         sample_z,
                         K,
                         x_grids,
                         y_grids,
                         title="sample")
    plt.savefig(rslt_dir + "/samples/sample_x_quiver_{}.jpg".format(sample_T))
    plt.close()

    plot_realdata_quiver(sample_x_center,
                         sample_z_center,
                         K,
                         x_grids,
                         y_grids,
                         title="sample (starting from center)")
    plt.savefig(rslt_dir +
                "/samples/sample_x_center_quiver_{}.jpg".format(sample_T))
    plt.close()

    if not os.path.exists(rslt_dir + "/dynamics"):
        os.makedirs(rslt_dir + "/dynamics")
        print("Making dynamics directory...")

    plot_weights(weights_a,
                 Df,
                 K,
                 x_grids,
                 y_grids,
                 max_weight=tran.transformations_a[0].acc_factor,
                 title="weights (virgin)")
    plt.savefig(rslt_dir + "/dynamics/weights_a.jpg")
    plt.close()

    plot_weights(weights_b,
                 Df,
                 K,
                 x_grids,
                 y_grids,
                 max_weight=tran.transformations_b[0].acc_factor,
                 title="weights (mother)")
    plt.savefig(rslt_dir + "/dynamics/weights_b.jpg")
    plt.close()

    plot_dynamics(weighted_corner_vecs_a,
                  "virgin",
                  x_grids,
                  y_grids,
                  K=K,
                  scale=quiver_scale,
                  percentage=grid_z_a_percentage,
                  title="grid dynamics (virgin)")
    plt.savefig(rslt_dir + "/dynamics/dynamics_a.jpg")
    plt.close()

    plot_dynamics(weighted_corner_vecs_b,
                  "mother",
                  x_grids,
                  y_grids,
                  K=K,
                  scale=quiver_scale,
                  percentage=grid_z_b_percentage,
                  title="grid dynamics (mother)")
    plt.savefig(rslt_dir + "/dynamics/dynamics_b.jpg")
    plt.close()

    # Columns 0:2 belong to the virgin, columns 2:4 to the mother.
    plot_quiver(XY_grids[:, 0:2],
                dXY[..., 0:2],
                'virgin',
                K=K,
                scale=quiver_scale,
                alpha=0.9,
                title="quiver (virgin)")
    plt.savefig(rslt_dir + "/dynamics/quiver_a.jpg")
    plt.close()

    plot_quiver(XY_grids[:, 2:4],
                dXY[..., 2:4],
                'mother',
                K=K,
                scale=quiver_scale,
                alpha=0.9,
                title="quiver (mother)")
    plt.savefig(rslt_dir + "/dynamics/quiver_b.jpg")
    plt.close()

    if not os.path.exists(rslt_dir + "/distributions"):
        os.makedirs(rslt_dir + "/distributions")
        print("Making distributions directory...")

    data_angles_a, data_angles_b = get_all_angles(data, x_grids, y_grids)
    sample_angles_a, sample_angles_b = get_all_angles(sample_x, x_grids,
                                                      y_grids)
    sample_x_center_angles_a, sample_x_center_angles_b = get_all_angles(
        sample_x_center, x_grids, y_grids)

    plot_list_of_angles(
        [data_angles_a, sample_angles_a, sample_x_center_angles_a],
        ['data', 'sample', 'sample_c'], "direction distribution (virgin)", n_x,
        n_y)
    plt.savefig(rslt_dir + "/distributions/angles_a.jpg")
    plt.close()
    plot_list_of_angles(
        [data_angles_b, sample_angles_b, sample_x_center_angles_b],
        ['data', 'sample', 'sample_c'], "direction distribution (mother)", n_x,
        n_y)
    plt.savefig(rslt_dir + "/distributions/angles_b.jpg")
    plt.close()

    data_speed_a, data_speed_b = get_speed(data, x_grids, y_grids)
    sample_speed_a, sample_speed_b = get_speed(sample_x, x_grids, y_grids)
    sample_x_center_speed_a, sample_x_center_speed_b = get_speed(
        sample_x_center, x_grids, y_grids)

    plot_list_of_speed([data_speed_a, sample_speed_a, sample_x_center_speed_a],
                       ['data', 'sample', 'sample_c'],
                       "speed distribution (virgin)", n_x, n_y)
    plt.savefig(rslt_dir + "/distributions/speed_a.jpg")
    plt.close()
    plot_list_of_speed([data_speed_b, sample_speed_b, sample_x_center_speed_b],
                       ['data', 'sample', 'sample_c'],
                       "speed distribution (mother)", n_x, n_y)
    plt.savefig(rslt_dir + "/distributions/speed_b.jpg")
    plt.close()

    # Best effort: a failure of the spatial distribution plots must not abort the run.
    # ``Exception`` (rather than a bare except) lets KeyboardInterrupt/SystemExit through.
    try:
        for points, file_stem in ((data, "space_data"),
                                  (sample_x, "space_sample_x"),
                                  (sample_x_center, "space_sample_x_center")):
            # Plot only when there are more than 100 points, and at most the first 36000.
            if points.shape[0] > 100:
                plot_space_dist(points[:36000], x_grids, y_grids)
            plt.savefig(rslt_dir + "/distributions/{}.jpg".format(file_stem))
            plt.close()
    except Exception:
        print("plot_space_dist unsuccessful")