def solve(system, initV=None, gamma=0.9, maxIt=1000, tol=1e-8):
    """Iterate ``ValueIteration.operT`` until the value vector stops changing.

    Starting from ``initV``, the operator is applied repeatedly; the loop
    stops as soon as the Euclidean norm of the difference between two
    successive iterates falls below ``tol``.

    Args:
        system: Object exposing ``system.network.numNodes``; it is forwarded
            unchanged to ``ValueIteration.operT``.
        initV: Starting value vector.  When ``None`` a zero vector of length
            ``1 << numNodes`` is used.
        gamma: Discount factor forwarded to ``ValueIteration.operT``.
        maxIt: Maximum number of operator applications.
        tol: Convergence threshold on the 2-norm of successive differences.

    Returns:
        The first iterate whose distance from its predecessor is below
        ``tol``.

    Raises:
        ValueError: If ``maxIt`` applications finish without meeting ``tol``.
    """
    numNodes = system.network.numNodes

    if initV is None:
        initV = np.zeros((1 << numNodes,))

    v0 = initV
    for _ in range(maxIt):
        v1 = ValueIteration.operT(system, gamma, v0)
        converged = np.linalg.norm(v1 - v0, 2) < tol
        v0 = v1
        if converged:
            # Return on convergence even when it happens on the very last
            # allowed sweep; only a genuinely unconverged run is an error.
            return v0

    raise ValueError("ValueIteration hit iteration limit")
def operT(system, gamma, v):
    """Apply one maximising backup to the value vector ``v``.

    For every valid treatment index, the transition matrix ``P`` and reward
    vector ``R`` are obtained from ``ValueIteration.calcPAndR`` and the
    candidate values ``R + gamma * P.dot(v)`` are stored as one column.
    The result is the row-wise maximum over those columns.

    Args:
        system: Forwarded to ``Agent.numTrt`` and ``ValueIteration.calcPAndR``;
            must expose ``system.network.numNodes``.
        gamma: Discount factor multiplying the expected next-state value.
        v: Current value vector (one entry per state, ``1 << numNodes`` total).

    Returns:
        Array of length ``1 << numNodes`` holding the best candidate value
        for each state.
    """
    nodeCount = system.network.numNodes
    actionCount = Agent.numValidTrt(nodeCount, Agent.numTrt(system))

    # One row per state, one column per treatment index.
    candidates = np.zeros((1 << nodeCount, actionCount))

    for action in range(actionCount):
        transition, expectedReward = ValueIteration.calcPAndR(system, action)
        discounted = gamma * (transition.dot(v))
        candidates[:, action] = expectedReward + discounted

    return candidates.max(axis=1)
def unitTest(cls):
    """Print value-iteration results next to policy-iteration results.

    Builds a ``System`` on a 3x3 grid network with a ``PJ`` model, solves it
    with ``ValueIteration.solve`` and ``PolicyIteration.solve``, reduces the
    unflattened Q output to a per-row maximum, and prints the two numbers
    side by side so they can be compared by eye.  Nothing is asserted or
    returned.

    Args:
        cls: Unused.
    """
    # Single-argument print(...) emits the same text under Python 2 and also
    # parses under Python 3.
    print("Testing ValueIteration")

    np.random.seed(0)

    from system import System
    from networks import genGridNetwork
    from model import PJ

    system = System(genGridNetwork((3, 3)), PJ())

    numNodes = system.network.numNodes
    numTrt = Agent.numTrt(system)
    numValidTrt = Agent.numValidTrt(numNodes, numTrt)

    # Each solver receives its own dc(system).
    # NOTE(review): dc is presumably copy.deepcopy -- confirm at the import.
    v = ValueIteration.solve(dc(system))

    q = PolicyIteration.solve(dc(system))
    q = util.unflattenQ(q, numNodes, numValidTrt)

    # Largest entry of each row of the unflattened Q.
    vChk = [max(row) for row in q]

    for pair in zip(v, vChk):
        print("% 12.6f % 10.6f" % pair)
def calcPAndR(system, trtInd):
    """Build the transition matrix and expected-reward vector for a treatment.

    The treatment index is converted to a combination with
    ``util.ind2Combo`` and applied through ``system.trtCmb``.  For every
    state ``s`` (an integer bitmask over nodes) the state is applied through
    ``system.infCmb`` and per-node probabilities are read from
    ``system.model.transProbs``.  The probability of moving to ``sp`` is the
    product, over nodes, of ``probs[i]`` where bit ``i`` differs between
    ``s`` and ``sp`` and of ``1 - probs[i]`` where it does not.

    The passed ``system`` has ``trtCmb`` and ``infCmb`` called on it, so
    callers that need their own object left alone should pass a copy.

    Args:
        system: Object exposing ``network.numNodes``, ``trtCmb``, ``infCmb``
            and ``model.transProbs``.
        trtInd: Treatment index understood by ``util.ind2Combo``.

    Returns:
        Tuple ``(P, R)``.  ``P`` has shape ``(1 << numNodes, 1 << numNodes)``
        with ``P[s, sp]`` the computed probability of ``s -> sp``.  ``R`` has
        length ``1 << numNodes`` with ``R[s]`` the sum over ``sp`` of
        ``P[s, sp] * reward(s, trtCmb, sp, numNodes)``.
    """
    # Factors smaller than probFloor contribute logFloor instead of
    # log(factor), which keeps log() away from zero or negative arguments.
    probFloor = 1e-13
    logFloor = -30

    numNodes = system.network.numNodes
    numTrt = Agent.numTrt(system)
    numStates = 1 << numNodes

    P = np.zeros((numStates, numStates))
    R = np.zeros((numStates,))

    trtCmb = util.ind2Combo(trtInd, numNodes, numTrt)
    system.trtCmb(cmb=trtCmb)

    for s in range(numStates):
        system.infCmb(cmb=s)
        probs = system.model.transProbs(system)

        # These per-node log terms depend only on s, so compute them once
        # here rather than again for every destination state.
        logChange = []
        logStay = []
        for i in range(numNodes):
            pChange = probs[i]
            pStay = 1.0 - probs[i]
            logChange.append(
                logFloor if pChange < probFloor else np.log(pChange))
            logStay.append(
                logFloor if pStay < probFloor else np.log(pStay))

        for sp in range(numStates):
            changes = s ^ sp

            # Sum logs in node order, then exponentiate once.
            logProb = 0.0
            for i in range(numNodes):
                if changes & (1 << i):
                    logProb += logChange[i]
                else:
                    logProb += logStay[i]

            prob = np.exp(logProb)
            P[s, sp] = prob
            R[s] += prob * reward(s, trtCmb, sp, numNodes)

    return P, R