def applyPSP(trace, node, scaffold, shouldRestore, omegaDB, gradients):
  """Give `node` a value by applying its PSP and return the resulting weight.

  The new value is chosen as follows:
    * if the scaffold has a local kernel for `node`, the kernel either
      forward-simulates a value (not restoring) or the saved old value is
      reused (restoring); the kernel's forward weight is accumulated, and
      for a VariationalLKernel the gradient of its log density is stored in
      `gradients[node]`;
    * otherwise the value is simulated from the PSP (not restoring) or the
      saved old value is reused (restoring), contributing no weight.

  The chosen value is written to the trace and incorporated into the PSP.
  If the value is an SP record, processMadeSP is run on the node. If the
  PSP is random, the node is registered as a random choice.

  Returns the weight as a Python float (0 when no local kernel is used).
  """
  weight = 0
  psp, args = trace.pspAt(node), trace.argsAt(node)
  assert isinstance(psp, PSP)

  # Previously saved value for this node, or None if omegaDB holds none.
  if omegaDB.hasValueFor(node):
    oldValue = omegaDB.getValue(node)
  else:
    oldValue = None

  if scaffold.hasLKernel(node):
    k = scaffold.getLKernel(node)
    if not shouldRestore:
      newValue = k.forwardSimulate(trace, oldValue, args)
    else:
      newValue = oldValue
    weight += k.forwardWeight(trace, newValue, oldValue, args)
    check_kernel_weight(weight, k, newValue, oldValue, args)
    weight = ensure_python_float(weight)
    if isinstance(k, VariationalLKernel):
      # Variational kernels also report a gradient for this node.
      gradients[node] = k.gradientOfLogDensity(newValue, args)
  else:
    # if we simulate from the prior, the weight is 0
    newValue = psp.simulate(args) if not shouldRestore else oldValue

  trace.setValueAt(node, newValue)
  psp.incorporate(newValue, args)

  if isinstance(newValue, VentureSPRecord):
    # The application produced a new SP; let the trace set it up.
    processMadeSP(trace, node, scaffold.isAAA(node))
  if psp.isRandom():
    trace.registerRandomChoice(node)
  maybeRegisterRandomChoiceInScope(trace, node)
  return ensure_python_float(weight)
def restore(trace, node, scaffold, omegaDB, gradients):
  """Rebuild `node` in restore mode (shouldRestore=True) and return its weight.

  * Constant nodes contribute nothing.
  * Lookup nodes: their parents are regenerated in restore mode, the lookup
    is reconnected, and the node takes its source node's value.
  * Output nodes: the operator node and each operand node are restored
    recursively, then the application (request node + this output node) is
    re-applied in restore mode.

  Returns the accumulated weight as a Python float on every path.
  """
  if isConstantNode(node):
    # Return a float, like every other path, rather than the int 0.
    return 0.0
  if isLookupNode(node):
    weight = regenParents(trace, node, scaffold, True, omegaDB, gradients)
    trace.reconnectLookup(node)
    trace.setValueAt(node, trace.valueAt(node.sourceNode))
    return ensure_python_float(weight)
  # Anything else must be an application's output node.
  assert isOutputNode(node)
  weight = restore(trace, node.operatorNode, scaffold, omegaDB, gradients)
  for operandNode in node.operandNodes:
    weight += restore(trace, operandNode, scaffold, omegaDB, gradients)
  weight += apply(trace, node.requestNode, node, scaffold, True, omegaDB,
                  gradients)
  return ensure_python_float(weight)
def regenParents(trace, node, scaffold, shouldRestore, omegaDB, gradients):
  """Regenerate every parent of `node` and return their combined weight.

  Definite parents are regenerated first, then ESR parents. The ESR parent
  list is fetched only after the definite parents are done, because
  regenerating a request node (via evalRequests) can add ESR edges.
  """
  total = 0
  # Getters are called lazily, one after the other, to preserve that order.
  for parentsOf in (trace.definiteParentsAt, trace.esrParentsAt):
    for parent in parentsOf(node):
      total += regen(trace, parent, scaffold, shouldRestore, omegaDB,
                     gradients)
  return ensure_python_float(total)
def absorb(trace, node):
  """Score `node`'s existing ground value under its PSP and absorb it.

  The value is not resampled. Its log density is computed and validated,
  then the value is incorporated into the PSP and the node is passed to
  maybeRegisterRandomChoiceInScope.

  Returns the log density as a Python float.
  """
  psp = trace.pspAt(node)
  args = trace.argsAt(node)
  groundValue = trace.groundValueAt(node)
  logDensity = psp.logDensity(groundValue, args)
  check_weight(logDensity, psp, args)
  psp.incorporate(groundValue, args)
  maybeRegisterRandomChoiceInScope(trace, node)
  return ensure_python_float(logDensity)
def constrain(trace, node, value):
  """Force `node` to hold `value`; return the log density of `value`.

  The node's current value is unincorporated from its PSP. `value` is scored
  and validated, stored in the trace, and incorporated. The node is then
  registered as a constrained choice.
  """
  psp = trace.pspAt(node)
  args = trace.argsAt(node)
  currentValue = trace.valueAt(node)
  psp.unincorporate(currentValue, args)

  logDensity = psp.logDensity(value, args)
  check_weight(logDensity, psp, args)

  trace.setValueAt(node, value)
  psp.incorporate(value, args)
  trace.registerConstrainedChoice(node)
  return ensure_python_float(logDensity)
def apply(trace, requestNode, outputNode, scaffold, shouldRestore, omegaDB,
          gradients):
  """Regenerate one application and return its total weight.

  Steps, in order: apply the request PSP, evaluate the requests it made,
  regenerate the output node's ESR parents, then apply the output PSP.
  """
  shared = (scaffold, shouldRestore, omegaDB, gradients)

  total = applyPSP(trace, requestNode, *shared)
  total += evalRequests(trace, requestNode, *shared)

  # The output node should have one ESR parent per ESR in the request.
  assert len(trace.esrParentsAt(outputNode)) == \
    len(trace.valueAt(requestNode).esrs)

  total += regenESRParents(trace, outputNode, *shared)
  total += applyPSP(trace, outputNode, *shared)
  return ensure_python_float(total)
def evalFamily(trace, address, exp, env, scaffold, shouldRestore, omegaDB,
               gradients):
  """Evaluate expression `exp` in `env` at `address`, creating trace nodes.

  Dispatches on the form of `exp`:
    * variable: the node the symbol resolves to is regenerated and a new
      lookup node pointing at it is created;
    * self-evaluating or quoted expression: a constant node is created
      (weight 0);
    * otherwise (a combination): each subexpression is evaluated at
      addr.extend(address, index), request/output application nodes are
      created with nodes[0] as operator and nodes[1:] as operands, and the
      application is applied.

  Returns a (weight, node) pair.

  Raises VentureException("evaluation", ...):
    * when symbol lookup raises VentureError (presumably an unbound
      symbol), tagged with `address`;
    * when applying raises VentureNestedRiplMethodError, tagged with the
      address carried by that error;
    * when applying raises any other non-Venture exception, tagged with
      `address`.
  VentureExceptions raised while applying propagate unchanged.
  """
  if e.isVariable(exp):
    try:
      sourceNode = env.findSymbol(exp)
    except VentureError as err:
      import sys
      info = sys.exc_info()
      # Python 2 three-argument raise: rewrap while keeping the original
      # traceback (info[2]).
      raise VentureException("evaluation", err.message, address=address), \
        None, info[2]
    weight = regen(trace, sourceNode, scaffold, shouldRestore, omegaDB,
                   gradients)
    return (weight, trace.createLookupNode(address, sourceNode))
  elif e.isSelfEvaluating(exp):
    return (0, trace.createConstantNode(address, exp))
  elif e.isQuotation(exp):
    return (0, trace.createConstantNode(address, e.textOfQuotation(exp)))
  else:
    weight = 0
    nodes = []
    for index, subexp in enumerate(exp):
      new_address = addr.extend(address, index)
      w, n = evalFamily(trace, new_address, subexp, env, scaffold,
                        shouldRestore, omegaDB, gradients)
      weight += w
      nodes.append(n)

    # nodes[0] is the operator; the rest are operands.
    (requestNode, outputNode) = \
      trace.createApplicationNodes(address, nodes[0], nodes[1:], env)
    try:
      weight += apply(trace, requestNode, outputNode, scaffold,
                      shouldRestore, omegaDB, gradients)
    except VentureNestedRiplMethodError as err:
      # This is a hack to allow errors raised by inference SP actions
      # that are ripl actions to blame the address of the maker of the
      # action rather than the current address, which is the
      # application of that action (which is where the mistake is
      # detected).
      import sys
      info = sys.exc_info()
      raise VentureException("evaluation", err.message, address=err.addr,
                             cause=err), None, info[2]
    except VentureException:
      raise # Avoid rewrapping with the below
    except Exception as err:
      # NOTE(review): on Python 2, BaseException.message is '' unless the
      # exception was built with exactly one argument, so the wrapped
      # message may be empty; str(err) may be more informative -- confirm.
      import sys
      info = sys.exc_info()
      raise VentureException("evaluation", err.message, address=address,
                             cause=err), None, info[2]
    return ensure_python_float(weight), outputNode
def regenAndAttachAtBorder(trace, border, scaffold, shouldRestore, omegaDB,
                           gradients):
  """Regenerate or attach every border node and return the total weight.

  Border nodes that the scaffold does not mark as absorbing are regenerated.
  Those that are observations additionally go through getAndConstrain,
  which receives an ordered mapping that is later passed to
  propagateConstraints. Absorbing border nodes are attached instead.
  """
  total = 0
  pendingConstraints = OrderedDict()
  for borderNode in border:
    if not scaffold.isAbsorbing(borderNode):
      total += regen(trace, borderNode, scaffold, shouldRestore, omegaDB,
                     gradients)
      if borderNode.isObservation:
        total += getAndConstrain(trace, borderNode, pendingConstraints)
    else:
      total += attach(trace, borderNode, scaffold, shouldRestore, omegaDB,
                      gradients)
  propagateConstraints(trace, pendingConstraints)
  return ensure_python_float(total)
def regen(trace, node, scaffold, shouldRestore, omegaDB, gradients):
  """Regenerate `node` if the scaffold resamples it; return the weight.

  For a resampling node, the actual work happens only when its regen count
  in this scaffold is 0. The parents are regenerated, then either the
  lookup is propagated (lookup nodes) or the PSP is applied, with requests
  also evaluated for request nodes. Every resampling visit increments the
  count. maybeRegenStaleAAA is always consulted, resampling or not.
  """
  total = 0
  if scaffold.isResampling(node):
    firstVisit = trace.regenCountAt(scaffold, node) == 0
    if firstVisit:
      total += regenParents(trace, node, scaffold, shouldRestore, omegaDB,
                            gradients)
      if not isLookupNode(node):
        total += applyPSP(trace, node, scaffold, shouldRestore, omegaDB,
                          gradients)
        if isRequestNode(node):
          total += evalRequests(trace, node, scaffold, shouldRestore,
                                omegaDB, gradients)
      else:
        propagateLookup(trace, node)
    trace.incRegenCountAt(scaffold, node)
  total += maybeRegenStaleAAA(trace, node, scaffold, shouldRestore, omegaDB,
                              gradients)
  return ensure_python_float(total)
def evalRequests(trace, node, scaffold, shouldRestore, omegaDB, gradients):
  """Fulfil the requests held by request node `node`; return their weight.

  `trace.valueAt(node)` is the request. For each exposed simulation request
  (ESR):
    * if the SP at `node` has no family registered under the ESR's id, one
      is obtained and registered. When restoring and omegaDB holds an ESR
      parent for (SP, id), that parent is restored. Otherwise a family is
      evaluated afresh with evalFamily at
      addr.request(node.address, esr.addr);
    * the registered family, whether new or pre-existing, is then linked to
      node.outputNode with an ESR edge.
  For each latent simulation request (LSR), the SP's simulateLatents is
  called with omegaDB's latent DB for that SP, or None if there is none.

  Returns the accumulated weight as a Python float.

  Raises VentureException("evaluation", ...) if evaluating an ESR's family
  itself registered a family under the same id.
  """
  assert isRequestNode(node)
  weight = 0
  request = trace.valueAt(node)

  # first evaluate exposed simulation requests (ESRs)
  for esr in request.esrs:
    if not trace.containsSPFamilyAt(node, esr.id):
      if shouldRestore and omegaDB.hasESRParent(trace.spAt(node), esr.id):
        esrParent = omegaDB.getESRParent(trace.spAt(node), esr.id)
        weight += restore(trace, esrParent, scaffold, omegaDB, gradients)
      else:
        address = addr.request(node.address, esr.addr)
        (w, esrParent) = evalFamily(trace, address, esr.exp, esr.env,
                                    scaffold, shouldRestore, omegaDB,
                                    gradients)
        weight += w
        if trace.containsSPFamilyAt(node, esr.id):
          # evalFamily already registered a family with this id for the
          # operator being applied here, which means a recursive call to
          # the operator issued a request for the same id. Currently,
          # the only way for that to happen is for a recursive memmed
          # function to call itself with the same arguments.
          raise VentureException("evaluation", "Recursive mem argument " \
            "loop detected.", address=node.address)
      trace.registerFamilyAt(node, esr.id, esrParent)

    # Link the (possibly shared, pre-existing) family to the output node.
    esrParent = trace.spFamilyAt(node, esr.id)
    trace.addESREdge(esrParent, node.outputNode)

  # next evaluate latent simulation requests (LSRs)
  for lsr in request.lsrs:
    if omegaDB.hasLatentDB(trace.spAt(node)):
      latentDB = omegaDB.getLatentDB(trace.spAt(node))
    else:
      latentDB = None
    weight += trace.spAt(node).simulateLatents(trace.argsAt(node), lsr,
                                               shouldRestore, latentDB)

  return ensure_python_float(weight)
def weight(self, _trace, newValue, args):
  """Return the log density of `newValue` under `self.psp` as a Python float.

  The trace argument is accepted for interface compatibility and unused.
  """
  return ensure_python_float(self.psp.logDensity(newValue, args))