forked from derkjan12/master_thesis
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dist_utils.py
75 lines (63 loc) · 2.72 KB
/
dist_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import distribution_sampler as sampler
import numpy as np
import dit, string, itertools
from typing import List
import pickle
import os
def get_vars(n_vars: int) -> List[str]:
    """Return the variable names for a system of ``n_vars`` variables.

    The first ``n_vars - 1`` names are the inputs ``X0, X1, ...`` and the
    final name is always the output ``Y``.
    """
    input_names = list(map('X{}'.format, range(n_vars - 1)))
    return input_names + ['Y']
def get_labels(n_vars: int, n_states: int) -> List[str]:
    """Enumerate every joint outcome label as a string of digits.

    Each label has ``n_vars`` characters drawn from the first ``n_states``
    decimal digits, listed in ``itertools.product`` order.

    Raises:
        ValueError: If ``n_states`` is below 1 or above 10 (only the ten
            decimal digits are available as state symbols).
    """
    too_few = n_states < 1
    too_many = n_states > 10
    if too_few or too_many:
        raise ValueError("states should be greater than 0 and less than or equal to 10")

    symbols = string.digits[:n_states]
    labels = []
    for combo in itertools.product(symbols, repeat=n_vars):
        labels.append(''.join(combo))
    return labels
def generate_distribution(n_vars: int, n_states: int, entropy_level: float, base=np.e) -> dit.Distribution:
    """Build a randomly sampled joint distribution over ``n_vars`` variables.

    Outcome labels come from ``get_labels``, variable names from ``get_vars``
    and the probabilities from ``sampler.sample``.

    Args:
        n_vars: Total number of variables (the inputs plus the output ``Y``).
        n_states: Number of states per variable.
        entropy_level: Forwarded to ``sampler.sample`` as ``level``.
        base: Base handed to ``dit.Distribution``. ``np.e`` (the default) or
            ``'e'`` stores natural-log probabilities; any other number stores
            log probabilities in that base; other strings (e.g. ``'linear'``)
            and ``None`` pass the sampled values through unchanged.

    Returns:
        The resulting ``dit.Distribution`` with its variable names set.
    """
    var_names = get_vars(n_vars)
    state_labels = get_labels(n_vars, n_states)
    # assumes sampler.sample returns linear probabilities -- TODO confirm
    pmf = sampler.sample(n_states**n_vars, level=entropy_level)

    # The pmf must be expressed in the same base that is declared to dit;
    # previously only base == np.e was converted, so any other log base
    # received linear values labelled as log probabilities.
    if base == np.e or base == 'e':
        pmf = np.log(pmf)
    elif base is not None and not isinstance(base, str):
        # Change of base: log_b(p) = ln(p) / ln(b).
        pmf = np.log(pmf) / np.log(base)

    d = dit.Distribution(state_labels, pmf=pmf, base=base)
    d.set_rv_names(var_names)
    return d
def get_marginals(d: dit.Distribution) -> (dit.Distribution, List[dit.Distribution]):
    """Condition ``d`` on all of its variables except the last one.

    The last random-variable name is treated as the output; every name
    before it is passed to ``d.condition_on``, whose result is returned
    unchanged.
    """
    all_names = d.get_rv_names()
    # Drop the trailing name: it is the output variable.
    input_names = all_names[:-1]
    return d.condition_on(input_names)
def get_joint(X: dit.Distribution, YgX: List[dit.Distribution], base=np.e) -> dit.Distribution:
    """Recombine a marginal and its conditionals into a joint distribution.

    ``X`` and ``YgX`` are handed to ``dit.joint_from_factors``; the result is
    returned as a copy made with the requested ``base``.
    """
    joint = dit.joint_from_factors(X, YgX)
    rebased = joint.copy(base=base)
    return rebased
def print_conditional(YgX):
    """Print the pmf of every conditional distribution in ``YgX``.

    Each distribution has ``make_dense()`` called on it before printing, so
    the inputs are modified in place. One line is written per distribution,
    prefixed by its position in ``YgX``.
    """
    for position, conditional in enumerate(YgX):
        conditional.make_dense()
        prefix = "{}: ".format(position)
        print(prefix, conditional.pmf)
def load_dist(model, parameter, n_vars, distribution_number, corrected=False):
    """Load a pickled distribution from the results directory tree.

    The file is looked up at
    ``<folder>/<p_folder>/n<n_vars>/d<distribution_number>.pkl`` relative to
    the current working directory.

    Args:
        model: ``"ising"`` or ``"sis"``.
        parameter: For ``"ising"`` a single number, formatted as
            ``temp{:.2f}``. For ``"sis"`` a ``(beta, gamma, alpha)`` triple;
            ``beta`` selects the folder while ``alpha`` and ``gamma`` are
            appended to the file name.
        n_vars: Number of variables, used for the ``n<n_vars>`` folder.
        distribution_number: Identifier of the distribution file.
        corrected: Choose the ``corrected_*`` folder instead of the
            ``uncorrected_*`` one.

    Returns:
        The unpickled object, or ``None`` when the file does not exist.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    if model == "ising":
        folder = "corrected_ising_distributions" if corrected else "uncorrected_ising_distributions"
        p_folder = "temp{:.2f}".format(parameter)
    elif model == "sis":
        folder = "corrected_sis_distributions" if corrected else "uncorrected_sis_distributions"
        beta, gamma, alpha = parameter
        p_folder = "beta{:.2f}".format(beta)
        distribution_number = str(distribution_number) + "_alpha{:.2f}_gamma{:.2f}".format(alpha, gamma)
    else:
        # Without this branch an unknown model reached the path formatting
        # below with `folder` unbound and failed with UnboundLocalError.
        raise ValueError("unknown model {!r}; expected 'ising' or 'sis'".format(model))

    dist_file = "{}/{}/n{}/d{}.pkl".format(folder, p_folder, n_vars, distribution_number)
    if not os.path.exists(dist_file):
        return None

    # NOTE: pickle.load can execute arbitrary code while unpickling; only use
    # this on distribution files that come from a trusted source.
    with open(dist_file, "rb") as f:
        return pickle.load(f)
def calculate_XY(X, YgX, model="ising"):
    """Combine the marginal ``X`` with ``YgX`` into a ``dit.ScalarDistribution``.

    ``X`` is made dense in place. Candidate outcomes are all tuples of length
    ``X.outcome_length() + 1`` over ``X.alphabet``; they are paired, in order,
    with the entries obtained by iterating over ``X.pmf[:, np.newaxis] * YgX``.

    The ``model`` argument is accepted but not used.
    """
    X.make_dense()
    symbols = X.alphabet
    joint_length = X.outcome_length() + 1
    outcomes = list(itertools.product(symbols, repeat=joint_length))

    # Scale each row of YgX by the probability of the matching X outcome.
    # assumes YgX is array-like with one row per outcome of X -- TODO confirm
    weighted = X.pmf[:, np.newaxis] * YgX

    # NOTE(review): iterating `weighted` appears to yield whole rows, so each
    # outcome is paired with a row rather than a single probability, and zip
    # stops at the shorter of the two sequences -- confirm this is intended.
    outcome_to_prob = dict(zip(outcomes, weighted))
    return dit.ScalarDistribution(outcome_to_prob)
def calculate_Y(X, YgX, model="ising"):
    """Marginalise ``X`` out of ``YgX`` to obtain the distribution of ``Y``.

    ``X`` is made dense in place. Each row of ``YgX`` is weighted by the
    matching entry of ``X.pmf`` and the rows are summed. The result is
    labelled with outcomes ``[-1, 1]`` when ``model == "ising"`` and
    ``[0, 1]`` otherwise, and returned as a ``dit.ScalarDistribution``.
    """
    X.make_dense()

    # assumes YgX is array-like with one row per outcome of X -- TODO confirm
    weighted = X.pmf[:, np.newaxis] * YgX
    y_probs = weighted.sum(axis=0)

    if model == "ising":
        y_outcomes = [-1, 1]
    else:
        y_outcomes = [0, 1]

    outcome_to_prob = dict(zip(y_outcomes, y_probs))
    return dit.ScalarDistribution(outcome_to_prob)