def display_model(training_data, alpha, beta, gamma, delta):
    """
    Plot the training data beneath the maximum a posteriori (MAP) normal
    distribution, alongside the maximum likelihood normal for the data and
    the normal described by the prior.

    :param training_data: The training data to fit the distribution to.
    :type training_data: list[float]
    :param alpha: The alpha hyperparameter of the prior
    :type alpha: float
    :param beta: The beta hyperparameter of the prior
    :type beta: float
    :param gamma: The gamma hyperparameter of the prior
    :type gamma: float
    :param delta: The delta hyperparameter of the prior
    :type delta: float
    """
    hyperparameters = (alpha, beta, gamma, delta)

    # An empty data set is used to obtain the parameters for the prior curve.
    prior_mean, prior_variance = generate_parameters([], *hyperparameters)
    ml_mean, ml_variance = generate_maximum_likelihood_parameters(training_data)
    map_mean, map_variance = generate_parameters(training_data, *hyperparameters)

    # The drawing helpers are given sigma (square root of the variance).
    prior_sigma = prior_variance ** 0.5
    ml_sigma = ml_variance ** 0.5
    map_sigma = map_variance ** 0.5

    # One palette lookup per curve, taken in palette order 0..3.
    map_color, data_color, prior_color, ml_color = (
        sns.color_palette()[index] for index in range(4)
    )

    display.draw_normal(prior_mean, prior_sigma, color=prior_color)
    display.draw_normal(ml_mean, ml_sigma, color=ml_color)
    display.draw_data_under_normal(
        training_data,
        map_mean,
        map_sigma,
        data_color=data_color,
        normal_color=map_color,
    )
    plt.show()
def test_draw_data_under_normal(self, mock_plt, mock_draw_normal):
    """
    Check that every data point is plotted as a vertical segment from 0.0 up
    to the normal pdf at that point, and that the normal curve itself is
    drawn with the given mean and sigma and no explicit color.

    NOTE(review): the mock arguments are presumably injected by patch
    decorators that are not visible in this chunk - confirm.
    """
    mean, sigma = 1, 1
    data = [0.9, 1.0]

    draw_data_under_normal(data, mean, sigma)

    recorded_plot_calls = mock_plt.plot.call_args_list
    for point in data:
        pdf_at_point = scipy.stats.norm.pdf(point, mean, sigma)
        expected_call = (([point, point], [0.0, pdf_at_point]),)
        assert expected_call in recorded_plot_calls

    expected_normal_call = ((mean, sigma), {'color': None})
    assert mock_draw_normal.call_args == expected_normal_call
def display_model(training_data):
    """
    Plot the input data beneath the maximum likelihood normal distribution
    fitted to it.

    :param training_data: The training data to fit the distribution to.
    :type training_data: list[float]
    """
    ml_mean, ml_variance = generate_parameters(training_data)
    # The drawing helper is given sigma (square root of the variance).
    ml_sigma = ml_variance ** 0.5

    curve_color = sns.color_palette()[0]
    point_color = sns.color_palette()[1]

    display.draw_data_under_normal(
        training_data,
        ml_mean,
        ml_sigma,
        data_color=point_color,
        normal_color=curve_color,
    )
    plt.show()
def test_draw_normal_can_be_passed_a_color_for_the_normal_and_data(
        self, mock_norm, mock_draw_normal, mock_plt):
    """
    Check that the normal_color argument is forwarded to the normal-drawing
    helper and the data_color argument to the most recent plot call.

    NOTE(review): the mock arguments are presumably injected by patch
    decorators that are not visible in this chunk - confirm.
    """
    data_color = 'r'
    normal_color = 'b'

    draw_data_under_normal(
        [0], 0, 0, data_color=data_color, normal_color=normal_color)

    # call_args[1] holds the keyword arguments of the last recorded call.
    normal_kwargs = mock_draw_normal.call_args[1]
    assert normal_kwargs['color'] == normal_color

    plot_kwargs = mock_plt.plot.call_args[1]
    assert plot_kwargs['color'] == data_color