/
m-h.py
124 lines (89 loc) · 2.81 KB
/
m-h.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#an algorithm which takes as an argument any probability density function (not necessarily normalized) and returns a collection of samples from the normalized distribution
import numpy as np
import pylab as plt
import scipy
from scipy.stats import norm
from scipy.stats import expon
# Example unnormalized target densities used to exercise the sampler.
# (Converted from lambda assignments to `def`s per PEP 8 E731; the
# names and call signatures are unchanged.)
def gauss(x):
    """Unnormalized standard Gaussian bump (weight 10)."""
    return 10 * norm.pdf(x)

def sumgauss(x):
    """Mixture of two Gaussian bumps: N(0, 1) and N(20, 4)."""
    return 10 * norm.pdf(x) + 15 * norm.pdf(x, loc=20, scale=4)

def exp(x):
    """Unnormalized exponential density (weight 10).

    NOTE(review): the name `exp` would shadow `numpy.exp`/`math.exp`
    under a star import; consider renaming at the next interface change.
    """
    return 10 * expon.pdf(x)

def exgauss(x):
    """Exponential bump at 0 plus a Gaussian bump at 20."""
    return 10 * expon.pdf(x) + 15 * norm.pdf(x, loc=20, scale=4)
def metrop(dist, n=1000, full_out=False):
    """Metropolis-Hastings sampler for an unnormalized 1-D density.

    Parameters
    ----------
    dist : callable
        Unnormalized probability density function of one float argument.
    n : int
        Number of post-burn-in samples to draw.
    full_out : bool
        If True, also return the acceptance rate of the main loop.

    Returns
    -------
    samples : list of float
        Length n + 1: the chain state at the end of burn-in followed by
        n draws (rejected proposals repeat the current state).
    info : float, only if full_out
        Fraction of the n main-loop proposals that were accepted.
    """
    # TODO look up how to initialize markov-chain?
    # NOTE(review): start point (10) and proposal width (5) are
    # hard-coded; they suit the example densities centered near 0-20.
    init = 10
    samples = []
    accepted = 0.
    logdis = lambda x: np.log(dist(x))
    # --- burn-in: advance the chain through 100 accepted moves ---
    count = 0
    while True:
        x_prime = np.random.normal(init, 5)   # symmetric Gaussian proposal
        a = dist(x_prime) / dist(init)        # MH acceptance ratio
        if logdis(x_prime) >= logdis(init):
            init = x_prime
            count += 1
        else:
            if np.random.rand() <= a:
                init = x_prime
                count += 1
        if count > 100:
            break
    # The break only fires on the iteration whose proposal was accepted,
    # so `init` is the current chain state; seed the output with it.
    samples += [init]
    # --- main sampling loop ---
    for u in range(n):
        x_t = samples[u]
        # we will use a Gaussian centered at x_t for the proposal Q
        x_prime = np.random.normal(x_t, 5)
        a = dist(x_prime) / dist(x_t)
        if logdis(x_prime) >= logdis(x_t):
            samples += [x_prime]
            accepted += 1.
        else:
            if np.random.rand() <= a:
                samples += [x_prime]
                accepted += 1.
            else:
                # rejected: the chain stays at the current state
                samples += [x_t]
    # BUG FIX: the acceptance rate was computed as accepted/1000
    # regardless of n; divide by the actual number of proposals.
    info = accepted / n
    if full_out:
        return samples, info
    return samples
def kldiv(samples,dist):
    """Approximate the KL divergence D(P || Q) between the target density
    `dist` (P, scaled by the bin width over the sampled range) and the
    empirical distribution of `samples` (Q), via a Riemann sum on a
    1000-point grid.

    Parameters:
        samples: sequence of floats produced by the sampler.
        dist: callable, the (unnormalized) target density.

    Returns:
        float: Riemann-sum estimate of D(P|Q).
    """
    #attempt to compute k-l divergence
    #d(P|Q) = \int(p(x)*log(p(x)/q(x))dx)
    #approximate by riemann sum
    samples = np.array(samples)
    n = 1000
    # Grid spanning the sampled range; p(x) is evaluated at the n-1 left bin edges.
    x = np.linspace(samples.min(),samples.max(),n)
    px = np.array([dist(xx) for xx in x[:n-1]])
    # approximate q(x) by (number of samples in [x[u], x[u+1])) / len(samples)
    # Broadcast the samples into n identical rows so each bin u can test
    # every sample against its edges [x[u], x[u+1]).
    sampmatrix = samples*np.ones((n,1))
    qx = [[x[u]<=a<x[u+1] for a in sampmatrix[u]] for u in range(len(x)-1)]
    # Histogram estimate of q: fraction of samples landing in each bin.
    qx = np.array(qx).sum(axis=1)/(1.*len(samples))
    #nonzeroq = qx.nonzero()[0]
    #qx = qx[nonzeroq]
    #px = px[nonzeroq]
    # Forward-fill empty bins with the previous bin's mass so log(p/q)
    # stays finite.  NOTE(review): qx[0] is never filled — if the first
    # bin is empty the first retained term can still be inf/NaN.
    for i in range(1,len(qx)):
        if not(qx[i]==0):
            pass
        else:
            qx[i] = qx[i-1]
    # Bin width; multiplying p by it roughly normalizes P over the range.
    l = (samples.max()-samples.min())/len(qx)
    # Drop bins where p(x) == 0 so the integrand p*log(p/q) is defined.
    nonzerop = px.nonzero()[0]
    qx = qx[nonzerop]
    px = l*px[nonzerop]
    # Riemann sum of p(x) * log(p(x)/q(x)) over the surviving bins.
    evals = px*np.log(px/qx)
    dpq = np.sum(evals)
    return dpq
#def plotkl(dist):
# d100 = kldiv(metrop(dist,100),dist)
# d1000 = kldiv(metrop(dist,1000),dist)
# d10000 = kldiv(metrop(dist,10000),dist)
# plt.plot([d100,d1000,d10000])
# plt.show()
# return [d100,d1000,d10000]