def _get_action_dist(self):
    """Update the stored state belief and return the marginal over A[t].

    Calls ``self._set_fixed()`` first, then performs one filtering step:
    at ``t == 0`` the belief is ``L[0] * pS0`` normalized over ``S[0]``;
    otherwise the belief kept in ``self.memory`` is multiplied by
    ``Tr[t-1]``, summed over ``S[t-1]``, multiplied by ``L[t]`` and
    normalized over ``S[t]``. The result overwrites ``self.memory``.

    The belief is then rolled forward through ``Tr`` for every step
    ``t+1 .. T`` while a value term is accumulated; the action distribution
    is the normalized ``pA0 * pr.exp_tr(self.alpha * value)``.

    Returns:
        The ``.val`` of the marginal of the action distribution with
        respect to ``A[t]``.
    """
    self._set_fixed()
    horizon, now = self.T, self.t
    states, actions, obs = self.S, self.A, self.X
    trans, lik, desired = self.Tr, self.L, self.D
    prior_state, prior_action = self.pS0, self.pA0

    # inference: bring the state belief up to the current step
    if now != 0:
        carried = pr.sum([states[now - 1]], trans[now - 1] * self.memory)
        belief = (lik[now] * carried).normalize([states[now]])
    else:
        belief = (lik[0] * prior_state).normalize([states[0]])
    self.memory = belief

    # action distribution
    forecast = belief  # running predictive state distribution
    value = 0  # running value function
    for step in range(now + 1, horizon + 1):
        forecast = pr.sum([states[step - 1]], trans[step - 1] * forecast)

        log_gain = pr.log(lik[step]) + pr.log(desired[step])
        expected_gain = pr.sum(
            [obs[step], states[step]],
            log_gain * lik[step] * forecast,
        )

        predicted_obs = pr.sum([states[step]], lik[step] * forecast)
        expected_log_pred = pr.sum(
            [obs[step], states[step]],
            pr.log(predicted_obs) * lik[step] * forecast,
        )

        value = value + (expected_gain - expected_log_pred)

    action_post = (prior_action * pr.exp_tr(self.alpha * value)).normalize()
    # marginal of the action distribution with respect to A[t]
    return pr.sum(action_post, [actions[now]]).val
def meanField(self):
    """Approximate the marginal over the current action A[t] by mean-field iteration.

    Every state factor ``qS[i]`` (``i = 0 .. T``) is initialised uniform over
    ``S[i]`` jointly with the actions ``A[t] .. A[T-1]``; the action factor
    ``qA`` is initialised uniform over all of ``A``. Each round then

    1. refreshes ``qS[0]``, the factors ``qS[1] .. qS[t]`` (which include
       ``L[tau]``), and the factors ``qS[t+1] .. qS[T]``;
    2. accumulates the objective ``G`` from the refreshed factors;
    3. sets ``qA`` to the normalised ``pA0 * pr.exp_tr(self.alpha * G)``.

    Iteration stops after ``self.iterations`` rounds, or earlier once the
    Euclidean norm of the change in the ``A[t]`` marginal drops below
    ``precision`` (1e-4).

    Returns:
        The ``.val`` of the marginal of ``qA`` with respect to ``A[t]``.
        With ``self.iterations == 0`` this is the marginal of the initial
        uniform ``qA``.
    """
    precision = 1e-4
    T, S, A, X, t = self.T, self.S, self.A, self.X, self.t
    Tr, L, D, pS0, pA0 = self.Tr, self.L, self.D, self.pS0, self.pA0

    # initialize
    qA = pr.func(vars=list(A.values()), val='unif').normalize()
    remaining_actions = [A[j] for j in range(t, T)]
    qS = {
        i: pr.func(vars=[S[i], *remaining_actions], val='unif').normalize([S[i]])
        for i in range(T + 1)
    }
    # Seed the result so that a zero-iteration run returns the initial
    # marginal instead of failing on an unbound name.
    qAmarg = pr.sum(qA, [A[t]]).val

    # repeat
    q0 = 0
    for _ in range(self.iterations):
        # initial state
        qS[0] = (
            L[0] * pS0 * pr.exp_tr(pr.sum([S[1]], pr.log(Tr[0]) * qS[1]))
        ).normalize([S[0]])

        # past and current states: both expected log-transition terms go
        # inside exp_tr, exactly as in the future-state update below (the
        # second term used to be added outside the exponential).
        for tau in range(1, t + 1):
            qS[tau] = (
                L[tau] * pr.exp_tr(
                    pr.sum([S[tau - 1]], pr.log(Tr[tau - 1]) * qS[tau - 1])
                    + pr.sum([S[tau + 1]], pr.log(Tr[tau]) * qS[tau + 1])
                )
            ).normalize([S[tau]])

        # future states
        for tau in range(t + 1, T):
            qS[tau] = pr.exp_tr(
                pr.sum([S[tau - 1]], pr.log(Tr[tau - 1]) * qS[tau - 1])
                + pr.sum([S[tau + 1]], pr.log(Tr[tau]) * qS[tau + 1])
            ).normalize([S[tau]])
        qS[T] = pr.exp_tr(
            pr.sum([S[T - 1]], pr.log(Tr[T - 1]) * qS[T - 1])
        ).normalize([S[T]])

        # objective: future steps, then the initial state, then past steps
        G = 0
        for tau in range(t + 1, T + 1):
            G = G + (
                pr.sum([S[tau - 1], S[tau]], pr.log(Tr[tau - 1]) * qS[tau] * qS[tau - 1])
                - pr.sum([S[tau]], pr.log(qS[tau]) * qS[tau])
                + pr.sum(
                    [X[tau], S[tau]],
                    (pr.log(L[tau]) + pr.log(D[tau])) * L[tau] * qS[tau],
                )
                - pr.sum(
                    [X[tau], S[tau]],
                    pr.log(pr.sum([S[tau]], L[tau] * qS[tau])) * L[tau] * qS[tau],
                )
            )
        G = G + (
            pr.sum([S[0]], pr.log(pS0) * qS[0])
            - pr.sum([S[0]], pr.log(qS[0]) * qS[0])
            + pr.sum([S[0]], pr.log(L[0]) * qS[0])
        )
        for tau in range(1, t + 1):
            G = G + (
                pr.sum([S[tau - 1], S[tau]], pr.log(Tr[tau - 1]) * qS[tau] * qS[tau - 1])
                - pr.sum([S[tau]], pr.log(qS[tau]) * qS[tau])
                + pr.sum([S[tau]], pr.log(L[tau]) * qS[tau])
            )

        # action factor and convergence check on its A[t] marginal
        qA = (pA0 * pr.exp_tr(self.alpha * G)).normalize()
        qAmarg = pr.sum(qA, [A[t]]).val
        if np.linalg.norm(qAmarg - q0) < precision:
            break
        q0 = qAmarg
    return qAmarg
def meanField(self):
    """Approximate the marginal over the current action A[t] by mean-field iteration.

    State factors ``qS[i]`` are initialised uniform over ``S[i]``; for
    ``i > t`` they additionally range over the actions ``A[t] .. A[i-1]``.
    The action factor ``qA`` is initialised uniform over all of ``A``.
    Each round then

    1. refreshes the past and current factors ``qS[0] .. qS[t]`` in closed
       form;
    2. overwrites ``.val`` of the future factors ``qS[t+1] .. qS[T]`` with
       the output of ``get_qS_GD`` (defined elsewhere);
    3. accumulates ``G`` over the future steps and sets ``qA`` to the
       normalised ``pA0 * pr.exp_tr(self.alpha * G)``.

    Iteration stops after ``self.iterations`` rounds, or earlier once the
    Euclidean norm of the change in the ``A[t]`` marginal drops below
    ``precision`` (1e-3).

    Returns:
        The ``.val`` of the marginal of ``qA`` with respect to ``A[t]``.
        With ``self.iterations == 0`` this is the marginal of the initial
        uniform ``qA``.
    """
    precision = 1e-3
    N, K, T, S, A, X, t = self.N, self.K, self.T, self.S, self.A, self.X, self.t
    Tr, L, D, pS0, pA0 = self.Tr, self.L, self.D, self.pS0, self.pA0

    # initialize
    qA = pr.func(vars=list(A.values()), val='unif').normalize()
    # range(t, i) is empty for i <= t, so past and current factors range
    # over S[i] alone while future ones also carry A[t] .. A[i-1].
    qS = {
        i: pr.func(vars=[S[i], *[A[j] for j in range(t, i)]], val='unif').normalize([S[i]])
        for i in range(T + 1)
    }
    # Seed the result so that a zero-iteration run returns the initial
    # marginal instead of failing on an unbound name.
    qAmarg = pr.sum(qA, [A[t]]).val

    # repeat
    q0 = 0
    for _ in range(self.iterations):
        # PAST AND CURRENT STATES
        # Both expected log-transition terms go inside exp_tr (the second
        # one used to be added outside the exponential).
        if t > 0:
            qS[0] = (
                L[0] * pS0 * pr.exp_tr(pr.sum([S[1]], pr.log(Tr[0]) * qS[1]))
            ).normalize([S[0]])
            for tau in range(1, t):
                qS[tau] = (
                    L[tau] * pr.exp_tr(
                        pr.sum([S[tau - 1]], pr.log(Tr[tau - 1]) * qS[tau - 1])
                        + pr.sum([S[tau + 1]], pr.log(Tr[tau]) * qS[tau + 1])
                    )
                ).normalize([S[tau]])
            qS[t] = (
                L[t] * pr.exp_tr(
                    pr.sum([S[t - 1]], pr.log(Tr[t - 1]) * qS[t - 1])
                    + pr.sum(
                        [S[t + 1], A[t]],
                        pr.log(Tr[t]) * qS[t + 1] * pr.sum(qA, [A[t]]),
                    )
                )
            ).normalize([S[t]])
        else:
            qS[0] = (
                L[0] * pS0 * pr.exp_tr(
                    pr.sum(
                        [S[1], A[0]],
                        pr.log(Tr[0]) * qS[1] * pr.sum(qA, [A[0]]),
                    )
                )
            ).normalize([S[0]])

        # FUTURE STATES --- computed by gradient descent (get_qS_GD)
        for tau in range(t + 1, T):
            qp = qS[tau + 1].val
            qm = qS[tau - 1].val
            qS[tau].val = get_qS_GD(
                t, tau, T, N, K,
                self.pTrans, self.likelihood, self.pDesired,
                qm, qp, qA.val,
            )
        # the last factor has no successor; an empty list is passed in its place
        qS[T].val = get_qS_GD(
            t, T, T, N, K,
            self.pTrans, self.likelihood, self.pDesired,
            qS[T - 1].val, [], qA.val,
        )

        # ACTION
        G = 0
        for tau in range(t + 1, T + 1):
            G = G + (
                pr.sum([S[tau - 1], S[tau]], pr.log(Tr[tau - 1]) * qS[tau] * qS[tau - 1])
                - pr.sum([S[tau]], pr.log(qS[tau]) * qS[tau])
                + pr.sum(
                    [X[tau], S[tau]],
                    (pr.log(L[tau]) + pr.log(D[tau])) * L[tau] * qS[tau],
                )
                - pr.sum(
                    [X[tau], S[tau]],
                    pr.log(pr.sum([S[tau]], L[tau] * qS[tau])) * L[tau] * qS[tau],
                )
            )
        qA = (pA0 * pr.exp_tr(self.alpha * G)).normalize()
        qAmarg = pr.sum(qA, [A[t]]).val
        if np.linalg.norm(qAmarg - q0) < precision:
            break
        q0 = qAmarg
    return qAmarg
def meanField(self):
    """Approximate the marginal over the current action A[t] by mean-field iteration.

    State factors ``qS[i]`` are initialised uniform over ``S[i]``; for
    ``i > t`` they additionally range over the actions ``A[t] .. A[i-1]``.
    The action factor ``qA`` is initialised uniform over all of ``A``.
    Each of the ``self.iterations`` rounds (there is no early stopping)

    1. refreshes the past and current factors ``qS[0] .. qS[t]``;
    2. refreshes the future factors ``qS[t+1] .. qS[T]``, whose exponent
       includes ``pr.sum([X[tau]], pr.log(D[tau]) * L[tau])``;
    3. accumulates ``G`` over the future steps and sets ``qA`` to the
       normalised ``pA0 * pr.exp_tr(self.alpha * G)``.

    Returns:
        The ``.val`` of the marginal of ``qA`` with respect to ``A[t]``.
    """
    T, S, A, X, t = self.T, self.S, self.A, self.X, self.t
    Tr, L, D, pS0, pA0 = self.Tr, self.L, self.D, self.pS0, self.pA0

    # initialize
    qA = pr.func(vars=list(A.values()), val='unif').normalize()
    # range(t, i) is empty for i <= t, so past and current factors range
    # over S[i] alone while future ones also carry A[t] .. A[i-1].
    qS = {
        i: pr.func(vars=[S[i], *[A[j] for j in range(t, i)]], val='unif').normalize([S[i]])
        for i in range(T + 1)
    }

    # repeat
    for _ in range(self.iterations):
        # past and current states: both expected log-transition terms go
        # inside exp_tr, as in the future-state update below (the second
        # one used to be added outside the exponential).
        if t > 0:
            qS[0] = (
                L[0] * pS0 * pr.exp_tr(pr.sum([S[1]], pr.log(Tr[0]) * qS[1]))
            ).normalize([S[0]])
            for tau in range(1, t):
                qS[tau] = (
                    L[tau] * pr.exp_tr(
                        pr.sum([S[tau - 1]], pr.log(Tr[tau - 1]) * qS[tau - 1])
                        + pr.sum([S[tau + 1]], pr.log(Tr[tau]) * qS[tau + 1])
                    )
                ).normalize([S[tau]])
            qS[t] = (
                L[t] * pr.exp_tr(
                    pr.sum([S[t - 1]], pr.log(Tr[t - 1]) * qS[t - 1])
                    + pr.sum(
                        [S[t + 1], A[t]],
                        pr.log(Tr[t]) * qS[t + 1] * pr.sum(qA, [A[t]]),
                    )
                )
            ).normalize([S[t]])
        else:
            qS[0] = (
                L[0] * pS0 * pr.exp_tr(
                    pr.sum(
                        [S[1], A[0]],
                        pr.log(Tr[0]) * qS[1] * pr.sum(qA, [A[0]]),
                    )
                )
            ).normalize([S[0]])

        # future states
        for tau in range(t + 1, T):
            qS[tau] = pr.exp_tr(
                pr.sum([S[tau - 1]], pr.log(Tr[tau - 1]) * qS[tau - 1])
                + pr.sum(
                    [S[tau + 1], A[tau]],
                    pr.log(Tr[tau]) * qS[tau + 1] * pr.sum(qA, [A[tau]]),
                )
                + pr.sum([X[tau]], pr.log(D[tau]) * L[tau])
            ).normalize([S[tau]])
        qS[T] = pr.exp_tr(
            pr.sum([S[T - 1]], pr.log(Tr[T - 1]) * qS[T - 1])
            + pr.sum([X[T]], pr.log(D[T]) * L[T])
        ).normalize([S[T]])

        # action
        G = 0
        for tau in range(t + 1, T + 1):
            # NOTE(review): the last term sums log(D[tau]) over X[tau] and
            # S[tau] without the L[tau] * qS[tau] weighting that the state
            # update above applies to log(D[tau]) -- confirm this is intended.
            G = G + (
                pr.sum([S[tau - 1], S[tau]], pr.log(Tr[tau - 1]) * qS[tau] * qS[tau - 1])
                - pr.sum([S[tau]], pr.log(qS[tau]) * qS[tau])
                + pr.sum([X[tau], S[tau]], pr.log(D[tau]))
            )
        qA = (pA0 * pr.exp_tr(self.alpha * G)).normalize()
    return pr.sum(qA, [A[t]]).val