def display_model(training_data, alpha, beta, gamma, delta):
    """
    Plot the training data under the maximum a posteriori (MAP) normal distribution, alongside the
    maximum likelihood (ML) normal fit for the data and the normal obtained from the prior alone.

    The "prior" curve comes from calling ``generate_parameters`` with an empty data list, so it depends
    only on the hyperparameters. The figure is displayed with ``plt.show()``.

    :param training_data: The training data to fit the distribution to.
    :type training_data: list[float]
    :param alpha: The alpha hyperparameter of the prior
    :type alpha: float
    :param beta: The beta hyperparameter of the prior
    :type beta: float
    :param gamma: The gamma hyperparameter of the prior
    :type gamma: float
    :param delta: The delta hyperparameter of the prior
    :type delta: float
    """
    hyperparameters = (alpha, beta, gamma, delta)

    # With no observations, generate_parameters yields the parameters implied by the prior only.
    prior_mean, prior_variance = generate_parameters([], *hyperparameters)
    ml_mean, ml_variance = generate_maximum_likelihood_parameters(training_data)
    map_mean, map_variance = generate_parameters(training_data, *hyperparameters)

    # The drawing helpers take a standard deviation, not a variance.
    prior_sigma, ml_sigma, map_sigma = (
        prior_variance ** 0.5,
        ml_variance ** 0.5,
        map_variance ** 0.5,
    )

    # Query the current seaborn palette once and assign one distinct color per curve/data series.
    palette = sns.color_palette()
    map_color, data_color, prior_color, ml_color = palette[0], palette[1], palette[2], palette[3]

    display.draw_normal(prior_mean, prior_sigma, color=prior_color)
    display.draw_normal(ml_mean, ml_sigma, color=ml_color)
    display.draw_data_under_normal(
        training_data,
        map_mean,
        map_sigma,
        data_color=data_color,
        normal_color=map_color,
    )
    plt.show()
예제 #2
0
    def test_draw_normal(self, mock_plt):
        """
        draw_normal should plot the normal pdf over the 1st-99th percentile range.

        The expected x values are ``constant.plot_samples`` evenly spaced points, and the expected
        y values are the pdf at those points. ``mock_plt`` presumably replaces the plotting module
        used by ``draw_normal`` (the patch decorator is not visible here -- confirm).
        """
        loc = 1
        scale = 1
        lower_bound = scipy.stats.norm.ppf(0.01, loc, scale)
        upper_bound = scipy.stats.norm.ppf(0.99, loc, scale)
        expected_x = np.linspace(lower_bound, upper_bound, constant.plot_samples)
        expected_y = scipy.stats.norm.pdf(expected_x, loc, scale)

        draw_normal(loc, scale)

        # call_args[0] holds the positional arguments of the most recent plt.plot call.
        plotted_positional_args = mock_plt.plot.call_args[0]
        assert np.array_equal((expected_x, expected_y), plotted_positional_args)
예제 #3
0
    def test_draw_normal_can_be_passed_a_color_for_the_normal(self, mock_norm, mock_linspace, mock_plt):
        """
        A ``color`` keyword given to draw_normal should be forwarded to the plot call.

        ``mock_norm`` and ``mock_linspace`` are not inspected here. Presumably they are patched so that
        a zero sigma does not reach real scipy/numpy code (decorators not visible -- confirm).
        """
        requested_color = 'r'
        draw_normal(0, 0, color=requested_color)

        # call_args[1] holds the keyword arguments of the most recent plt.plot call.
        plot_keyword_args = mock_plt.plot.call_args[1]
        assert plot_keyword_args['color'] == requested_color