def barplot_rating_dist(item, single=False, group=None, savefig=None,
                        *, step=1, max_rating=5):
    """Plot the rating histogram(s) of one item.

    The mode is picked in this order:

    * *group* given  -- histogram of that group's ratings only;
    * *single* true  -- one histogram over all ratings of the item;
    * otherwise      -- one outlined ("step") histogram per group, with a
      legend entry ``group <n>`` for each.

    Args:
        item: column of ``Data.get_ratings()`` to plot.
        single: plot one pooled histogram (ignored when *group* is given).
        group: index into ``Data.get_nyms()`` selecting a single group.
        savefig: ``None`` to display the figure with ``plt.show()``;
            otherwise a path the figure is saved to (150 dpi), after which
            the figure is cleared.
        step: rating granularity. Bins are centred on ``step, 2*step, ...,
            max_rating``; the default ``1`` gives one bin per integer rating,
            ``0.5`` one bin per half-star.
        max_rating: largest rating value, i.e. the centre of the last bin.

    Raises:
        ValueError: if *step* is not positive.
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step!r}")

    with msg("plotting rating distribution"):
        ratings = Data.get_ratings()[:, item]
        nyms = Data.get_nyms()

        plt.xlabel('rating')
        plt.ylabel('no. ratings')
        # Edges at step/2, 3*step/2, ..., max_rating + step/2. Built from an
        # integer bin count because np.arange with a float step can gain or
        # drop the final edge through rounding.
        n_bins = int(round(max_rating / step))
        bins = step / 2 + step * np.arange(n_bins + 1)

        def hist(values, **kwargs):
            # rwidth is already a fraction of the bin width; scaling it by
            # step (as step * 0.75) made bars overlap once step > 4/3.
            # matplotlib ignores rwidth for histtype='step'.
            return plt.hist(values, bins=bins, rwidth=0.75, **kwargs)

        # `.data` keeps only the stored entries of the ratings slice --
        # presumably a scipy sparse matrix whose unrated cells are implicit
        # zeros (TODO confirm).
        if group is not None:
            plt.title(f'Item {item}, group {group} rating distribution')
            hist(ratings[nyms[group]].data)
        elif single:
            plt.title(f'Item {item} rating distribution')
            hist(ratings.data)
        else:
            plt.title(f'Item {item}, all groups rating distributions')
            for nym_n, nym in enumerate(nyms):
                hist(ratings[nym].data, histtype='step', linewidth=2,
                     label=f'group {nym_n}')
            plt.legend()

        if savefig is None:
            plt.show()
        else:
            with msg(f'Saving figure to "{savefig}"'):
                plt.savefig(savefig, dpi=150)
            plt.clf()
def heatmap_rating_dist(item, savefig=None):
    """Draw a heatmap of how often each group gave each rating to one item.

    The matrix has one row per half-star rating level (0.5, 1.0, ..., 5.0)
    and one column per group; cell ``(k, j)`` counts the ratings equal to
    ``0.5 * (k + 1)`` given by group ``j`` of ``Data.get_nyms()``.

    Args:
        item: column of ``Data.get_ratings()`` to plot.
        savefig: ``None`` (default) to display the figure with
            ``plt.show()``; otherwise a path the figure is saved to (150 dpi),
            after which the figure is cleared.

    Raises:
        ValueError: if a rating falls outside the 0.5..5.0 range.
    """
    with msg("plotting rating distribution"):
        ratings = Data.get_ratings()[:, item]
        nyms = Data.get_nyms()

        levels = np.linspace(0.5, 5, 10)
        data = np.zeros((len(levels), len(nyms)))
        for nym_n, nym in enumerate(nyms):
            values = np.asarray(ratings[nym].data, dtype=float)
            # Rating r goes to row 2r - 1. rint instead of int(): int()
            # truncates, so float noise such as 2.9999 would drop a row.
            rows = np.rint(2 * values).astype(int) - 1
            # Without this check a rating of 0 maps to row -1 and is silently
            # counted as 5.0.
            if rows.size and (rows.min() < 0 or rows.max() >= len(levels)):
                raise ValueError(
                    f"item {item}, group {nym_n}: ratings must lie in "
                    f"0.5..5.0, got range [{values.min()}, {values.max()}]")
            # bincount adds up every rating that lands in the same row
            # instead of letting later values overwrite earlier ones.
            data[:, nym_n] = np.bincount(rows, minlength=len(levels))

        ax = sns.heatmap(data)
        ax.set(
            title="Distribution of item #{} ratings by group".format(int(item)),
            xlabel="group number",
            ylabel="rating",
            yticklabels=levels)

        if savefig is None:
            plt.show()
        else:
            with msg(f'Saving figure to "{savefig}"'):
                plt.savefig(savefig, dpi=150)
            plt.clf()
import numpy as np
import matplotlib.pyplot as plt

from myutils import msg
from datareader import DataReader
from dist_model import DiscreteNormal as DiscNorm

# Number of discrete rating levels; get_data_dist bins ratings
# 1..rating_count into an array of this length.
rating_count = 5
# Model distribution generator. The rating_count + 1 points
# 0.5, 1.5, ..., 5.5 presumably serve as bin edges, one bin per integer
# rating -- TODO confirm against dist_model.DiscreteNormal.
dist_gen = DiscNorm(np.linspace(0.5, 5.5, num=rating_count + 1))

# Load every dataset once, at import time, wrapped in the myutils `msg`
# context manager. What each getter returns is defined in
# datareader.DataReader; these names are module globals that other code
# presumably refers to directly.
with msg("Getting data"):
    Rtilde = DataReader.get_Rtilde()
    Rvar = DataReader.get_Rvar()
    R = DataReader.get_ratings()
    lam = DataReader.get_lam()
    P = DataReader.get_nyms()  # presumably the nym/group index sets


def get_data_dist(data, n_ratings=None):
    """Return the empirical distribution of integer ratings in *data*.

    Ratings are expected to be the integers ``1..n_ratings``. Entry ``k`` of
    the result is the fraction of ratings equal to ``k + 1``. Values are
    rounded to the nearest integer first (``np.rint``, ties to even), so
    float-encoded ratings such as ``2.9999999`` land in the right bin.

    Args:
        data: array-like of ratings; it is flattened before counting.
        n_ratings: number of rating levels; defaults to the module-level
            ``rating_count``.

    Returns:
        ``np.ndarray`` of shape ``(n_ratings,)`` that sums to 1, or all
        zeros when *data* is empty.

    Raises:
        ValueError: if any rating falls outside ``1..n_ratings``.
    """
    if n_ratings is None:
        n_ratings = rating_count
    values = np.asarray(data, dtype=float).ravel()
    # Round rather than truncate: astype(int) alone would map 2.9999 to 2.
    idx = np.rint(values).astype(int) - 1
    # Without this check a rating of 0 gives index -1 and silently lands in
    # the last bin.
    if idx.size and (idx.min() < 0 or idx.max() >= n_ratings):
        raise ValueError(
            f"ratings must lie in 1..{n_ratings}, "
            f"got range [{values.min()}, {values.max()}]")
    # bincount accumulates values sharing a bin; fancy-index assignment
    # would keep only the last one.
    counts = np.bincount(idx, minlength=n_ratings)
    total = counts.sum()
    if total == 0:
        return np.zeros(n_ratings)
    return counts / total


def get_err(data, mean, var):
    """Compare the empirical rating distribution of *data* with the model.

    Returns the element-wise absolute ratio ``|p_data / p_model|``, where
    ``p_data`` is ``get_data_dist(data)`` and ``p_model`` is
    ``dist_gen.pmf(mean, var)``.
    """
    # NOTE(review): despite the name this is a ratio, not a difference, and
    # both factors are presumably non-negative, so abs() has no effect --
    # confirm that |p_data - p_model| was not intended. For NumPy arrays a
    # zero in the model pmf yields inf/nan here.
    return abs(get_data_dist(data) / dist_gen.pmf(mean, var))