def testMult(self):
    """Multiplying ``self.pot`` by another potential gives the expected product.

    Two factors are checked: one over the single variable 3, and one over
    variables (3, 1), whose order differs from the product's (1, 2, 3).
    Each expected table is indexed by the product variables (1, 2, 3).
    """
    def make_pot(variables, card, table):
        # Build a Potential field by field from plain nested lists.
        built = Potential()
        built.variables = np.array(variables)
        built.card = np.array(card)
        built.table = np.array(table)
        return built

    cases = (
        (
            ([3], [3], [0.1, 0.4, 0.5]),
            ([1, 2, 3], [2, 2, 3],
             [[[0.02, 0.08, 0.1], [0.08, 0.32, 0.4]],
              [[0.06, 0.24, 0.3], [0.04, 0.16, 0.2]]]),
        ),
        (
            ([3, 1], [3, 2], [[0.2, 0.3], [0.2, 0.2], [0.6, 0.5]]),
            ([1, 2, 3], [2, 2, 3],
             [[[0.04, 0.04, 0.12], [0.16, 0.16, 0.48]],
              [[0.18, 0.12, 0.3], [0.12, 0.08, 0.2]]]),
        ),
    )

    for factor_spec, expected_spec in cases:
        factor = make_pot(*factor_spec)
        expected = make_pot(*expected_spec)
        self.assertTwoPot(self.pot * factor, expected)
def setpot(pot, evvariables, evidstates):
    """Clamp observed variables of a potential to their evidence states.

    Builds a potential over the variables of ``pot`` that are *not*
    observed. Each entry is read from ``pot.table`` with the observed
    variables fixed to their evidence states. The observed variables are
    dropped from the result.

    Args:
        pot: Potential with ``variables``, ``card`` and ``table``
            attributes. ``table`` is indexed with one index per entry of
            ``card``.
        evvariables: Observed variables. Entries that do not occur in
            ``pot.variables`` are ignored.
        evidstates: Observed states, aligned with ``evvariables``
            (``evidstates[k]`` is the state of ``evvariables[k]``).

    Returns:
        A new Potential over the unobserved variables of ``pot``. If
        ``pot`` shares no variable with ``evvariables``, a shallow copy of
        ``pot`` is returned instead.
    """
    # NOTE: the data format of the inputs is not unified (lists or
    # ndarrays may arrive), so index arrays are normalised to intp below.
    variables = pot.variables
    table = pot.table
    nstates = pot.card

    intersection, iv, iev = intersect(variables, evvariables)
    if intersection.size == 0:
        # NOTE(review): shallow copy -- the result shares ``table`` with
        # ``pot``, unlike the other branch, which builds a fresh table.
        return copy.copy(pot)

    newvar = setminus(variables, intersection)
    _, idx = ismember(newvar, variables)
    idx = np.asarray(idx, dtype=np.intp)
    iv = np.asarray(iv, dtype=np.intp)
    newns = nstates[idx]

    # iv[k] and iev[k] locate the same shared variable in ``variables`` and
    # ``evvariables`` respectively, so the evidence must be re-indexed by
    # iev to line up with iv (the MATLAB original uses evidstates(iev)).
    evid = np.atleast_1d(np.asarray(evidstates))[np.asarray(iev, dtype=np.intp)]

    newpot = Potential()
    newpot.variables = newvar
    newpot.card = newns
    shape = tuple(int(n) for n in newns)
    newpot.table = np.zeros(shape)

    # Full assignment over pot's variables. intp avoids the wrap-around the
    # previous int8 buffer suffered for states >= 128.
    oldassign = np.zeros(nstates.size, dtype=np.intp)
    oldassign[iv] = evid  # evidence positions are identical for every entry
    # np.ndindex visits every assignment of the remaining variables once;
    # with no remaining variables it yields a single empty tuple (0-d table).
    for newassign in np.ndindex(*shape):
        oldassign[idx] = newassign
        newpot.table[newassign] = table[tuple(oldassign)]
    return newpot