def computeRegenCounts(trace, drg, absorbing, aaa, border, brush, hardBorder):
    """Compute the initial regeneration count for every node of ``drg``.

    Starting counts, by role of the node:

    * in ``aaa``: 1 (raised further by the AAA pass below);
    * in ``hardBorder``: 1, regardless of how many children it has;
    * in ``border``: number of children + 1;
    * otherwise: number of children.

    When ``aaa`` is non-empty, ``maybeIncrementAAARegenCount`` is then applied
    to every parent of each node in ``drg`` and ``absorbing``.  It is also
    applied to every ESR parent of each output node in ``brush``, and to the
    source node of each lookup node in ``brush``.

    Returns:
        An ``OrderedDict`` mapping each node of ``drg`` to its count, in the
        iteration order of ``drg``.
    """
    counts = OrderedDict()
    for node in drg:
        if node in aaa or node in hardBorder:
            # AAA counts are topped up below; hardBorder nodes regenerate
            # despite the number of children.
            counts[node] = 1
        else:
            extra = 1 if node in border else 0
            counts[node] = len(trace.childrenAt(node)) + extra

    if not aaa:
        return counts

    for node in drg.union(absorbing):
        for parent in trace.parentsAt(node):
            maybeIncrementAAARegenCount(trace, counts, aaa, parent)

    for node in brush:
        if isOutputNode(node):
            for esrParent in trace.esrParentsAt(node):
                maybeIncrementAAARegenCount(trace, counts, aaa, esrParent)
        elif isLookupNode(node):
            maybeIncrementAAARegenCount(trace, counts, aaa, node.sourceNode)

    return counts
def extract(trace, node, scaffold, omegaDB, compute_gradient=False):
    """Extract ``node`` from the trace and return the accumulated weight.

    Stale AAA extraction via ``maybeExtractStaleAAA`` always runs first.  For a
    resampling node the scaffold regen count is decremented.  Only when it
    reaches 0 is the node itself torn down, followed by its parents:

    * application nodes: requests are un-evaluated (request nodes only), then
      the PSP is unapplied;
    * lookup/constant nodes: the value is cleared and, when
      ``compute_gradient`` is set, the node's partial is added to its (at most
      one) parent.
    """
    weight = 0
    weight += maybeExtractStaleAAA(trace, node, scaffold, omegaDB, compute_gradient)

    if not scaffold.isResampling(node):
        return weight

    trace.decRegenCountAt(scaffold, node)
    remaining = trace.regenCountAt(scaffold, node)
    assert remaining >= 0
    if remaining != 0:
        # Count not yet exhausted: nothing further to extract for this node.
        return weight

    if isApplicationNode(node):
        if isRequestNode(node):
            weight += unevalRequests(trace, node, scaffold, omegaDB, compute_gradient)
        weight += unapplyPSP(trace, node, scaffold, omegaDB, compute_gradient)
    else:
        trace.setValueAt(node, None)
        assert isLookupNode(node) or isConstantNode(node)
        assert len(trace.parentsAt(node)) <= 1
        if compute_gradient:
            # d/dx is 1 for a lookup node, so the partial passes straight through.
            for parent in trace.parentsAt(node):
                omegaDB.addPartial(parent, omegaDB.getPartial(node))

    weight += extractParents(trace, node, scaffold, omegaDB, compute_gradient)
    return weight
def restore(trace, node, scaffold, omegaDB, gradients):
    """Restore the family rooted at ``node`` and return its weight.

    Values are presumably replayed from ``omegaDB`` (it is only passed
    through here).

    * Constant nodes: nothing to do; returns 0.
    * Lookup nodes: regenerate the parents (passing ``True`` where ``regen``
      passes ``shouldRestore``), reconnect the lookup, and copy the source
      node's value.
    * Output nodes: restore the operator, then each operand in order, then
      ``apply`` the request/output pair with ``True`` in the restore position.

    Non-constant results are normalised with ``ensure_python_float``.
    """
    if isConstantNode(node):
        return 0

    if not isLookupNode(node):
        # Neither constant nor lookup: must be an output node.
        assert isOutputNode(node)
        total = restore(trace, node.operatorNode, scaffold, omegaDB, gradients)
        for operand in node.operandNodes:
            total += restore(trace, operand, scaffold, omegaDB, gradients)
        total += apply(trace, node.requestNode, node, scaffold, True, omegaDB, gradients)
        return ensure_python_float(total)

    total = regenParents(trace, node, scaffold, True, omegaDB, gradients)
    trace.reconnectLookup(node)
    trace.setValueAt(node, trace.valueAt(node.sourceNode))
    return ensure_python_float(total)
def getOutermostNonReferenceNode(self, node):
    """Follow reference indirections from ``node`` to the node that holds the value.

    * Constant nodes are returned as-is.
    * Lookup nodes are followed to their ``sourceNode``.
    * Output nodes whose PSP is an ``ESRRefOutputPSP`` are followed to their
      first ESR parent.
    * Output nodes with a tag output PSP are followed to operand 2.
    * Any other output node is returned.

    Raises:
        infer.MissingEsrParentError: if an ``ESRRefOutputPSP`` node has no ESR
            parents.
    """
    if isConstantNode(node):
        return node
    if isLookupNode(node):
        return self.getOutermostNonReferenceNode(node.sourceNode)

    assert isOutputNode(node)
    psp = self.pspAt(node)
    if isinstance(psp, ESRRefOutputPSP):
        esrParents = self.esrParentsAt(node)
        if not esrParents:
            # Could happen if this method is called on a torus, e.g. for rejection sampling
            raise infer.MissingEsrParentError()
        target = esrParents[0]
    elif isTagOutputPSP(psp):
        target = node.operandNodes[2]
    else:
        return node
    return self.getOutermostNonReferenceNode(target)
def propagateConstraint(trace, node, value):
    """Propagate a constrained value to ``node`` and, recursively, its children.

    ``node`` is expected to lie downstream of a node that gets constrained
    during regen.  Each visited node is handled by kind:

    * lookup node: its value is set to ``value``;
    * request node: allowed only when its PSP is a ``NullRequestPSP``;
    * output node: its PSP must be non-random, and the node is re-simulated
      from its current arguments.

    The recursion then visits every child of ``node``.  An output node hands
    its freshly simulated value to its children, so a lookup below a
    re-simulated output receives that output's new value.  Previously such a
    lookup received the constrained value from further upstream.  Lookup and
    request nodes pass ``value`` through unchanged.

    Args:
        trace: the trace being updated.
        node: the node to update.
        value: the current value of the node whose children are being visited
            (for the top-level call, presumably the constrained value — TODO
            confirm at the call sites).

    Raises:
        VentureException: if a request node with a non-null request PSP, or an
            output node with a random PSP, lies downstream of the constraint.
    """
    if isLookupNode(node):
        trace.setValueAt(node, value)
        childValue = value
    elif isRequestNode(node):
        if not isinstance(trace.pspAt(node), NullRequestPSP):
            raise VentureException(
                "evaluation",
                "Cannot make requests "
                "downstream of a node that gets constrained during regen",
                address=node.address)
        childValue = value
    else:
        # TODO there may be more cases to ban here.
        # e.g. certain kinds of deterministic coupling through mutation.
        assert isOutputNode(node)
        psp = trace.pspAt(node)
        if psp.isRandom():
            raise VentureException(
                "evaluation",
                "Cannot make random choices "
                "downstream of a node that gets constrained during regen",
                address=node.address)
        # TODO Is it necessary to unincorporate and incorporate here?  If
        # not, why not?
        childValue = psp.simulate(trace.argsAt(node))
        trace.setValueAt(node, childValue)
    for child in trace.childrenAt(node):
        propagateConstraint(trace, child, childValue)
def regen(trace, node, scaffold, shouldRestore, omegaDB, gradients):
    """Regenerate ``node`` and return the accumulated weight as a Python float.

    For a resampling node whose scaffold regen count is still 0, the parents
    are regenerated first.  Then a lookup node propagates its lookup; any
    other node applies its PSP and, if it is a request node, evaluates its
    requests.  Every visit to a resampling node increments its regen count.
    ``maybeRegenStaleAAA`` runs for every node, resampling or not.
    """
    # regenParents / applyPSP / evalRequests / maybeRegenStaleAAA all take
    # exactly these six arguments in this order.
    ctx = (trace, node, scaffold, shouldRestore, omegaDB, gradients)
    weight = 0
    if scaffold.isResampling(node):
        firstVisit = trace.regenCountAt(scaffold, node) == 0
        if firstVisit:
            weight += regenParents(*ctx)
            if isLookupNode(node):
                propagateLookup(trace, node)
            else:
                weight += applyPSP(*ctx)
                if isRequestNode(node):
                    weight += evalRequests(*ctx)
        trace.incRegenCountAt(scaffold, node)
    weight += maybeRegenStaleAAA(*ctx)
    return ensure_python_float(weight)
def updateValueAtNode(trace, scaffold, node, updatedNodes):
    """Bring the value of ``node`` up to date, recursing into stale parents.

    Strong assumption! Only resampling nodes in the scaffold are considered.
    Nodes already in ``updatedNodes`` are skipped; every other resampling node
    is added to ``updatedNodes`` once handled.

    * Lookup nodes update their source first, then copy its value.
    * Output nodes are handled only if their value is not an ``SPRef`` and
      their PSP does not let children AAA (SPRef and AAA nodes are assumed to
      always be updated).  Such a node recursively updates every parent its PSP
      cannot absorb.  If there was at least one such parent, it then calls
      ``update`` on itself.
    """
    if node in updatedNodes or not scaffold.isResampling(node):
        return

    if isLookupNode(node):
        source = node.sourceNode
        updateValueAtNode(trace, scaffold, source, updatedNodes)
        trace.setValueAt(node, trace.valueAt(source))
    elif isOutputNode(node):
        psp = trace.pspAt(node)
        assumedCurrent = isinstance(trace.valueAt(node), SPRef) or psp.childrenCanAAA()
        if not assumedCurrent:
            needsUpdate = False
            # Visit every parent (no early exit): each non-absorbable one is
            # updated before the next parent is checked.
            for parent in trace.parentsAt(node):
                if not psp.canAbsorb(trace, node, parent):
                    updateValueAtNode(trace, scaffold, parent, updatedNodes)
                    needsUpdate = True
            if needsUpdate:
                update(trace, node)
    updatedNodes.add(node)
def unevalFamily(trace, node, scaffold, omegaDB, compute_gradient=False):
    """Un-evaluate the family rooted at ``node`` and return the accumulated weight.

    * Constant nodes contribute nothing.
    * Lookup nodes (exactly one parent) optionally pass their partial to that
      parent, are disconnected, have their value cleared, and then extract
      their parents.
    * Output nodes are unapplied first, then their operands are un-evaluated
      in reverse order, then the operator.
    """
    weight = 0
    if isConstantNode(node):
        return weight

    if isLookupNode(node):
        assert len(trace.parentsAt(node)) == 1
        if compute_gradient:
            # d/dx is 1 for a lookup node, so the partial passes straight through.
            for parent in trace.parentsAt(node):
                omegaDB.addPartial(parent, omegaDB.getPartial(node))
        trace.disconnectLookup(node)
        trace.setValueAt(node, None)
        weight += extractParents(trace, node, scaffold, omegaDB, compute_gradient)
        return weight

    assert isOutputNode(node)
    weight += unapply(trace, node, scaffold, omegaDB, compute_gradient)
    for operand in reversed(node.operandNodes):
        weight += unevalFamily(trace, operand, scaffold, omegaDB, compute_gradient)
    weight += unevalFamily(trace, node.operatorNode, scaffold, omegaDB, compute_gradient)
    return weight
def extendCandidateScaffold(trace, pnodes, drg, absorbing, aaa, indexAssignments, i, hardBorder):
    """Classify nodes reachable from the principal nodes ``pnodes``.

    Uses a LIFO work stack of ``(node, isPrincipal, parentNode)`` entries,
    seeded with each principal node.  The stack is handed to
    ``addResamplingNode``, which presumably pushes further nodes onto it.
    Each popped node is routed by the first matching rule:

    1. already in ``drg`` but not in ``aaa``, a lookup node, or its operator
       is in ``drg``: resampling;
    2. in ``aaa``: AAA;
    3. not principal and its PSP can absorb from ``parentNode``: absorbing;
    4. its PSP lets children AAA: AAA;
    5. otherwise: resampling.
    """
    stack = [(pnode, True, None) for pnode in pnodes]

    while stack:
        node, isPrincipal, parentNode = stack.pop()
        mustResample = ((node in drg and node not in aaa)
                        or isLookupNode(node)
                        or node.operatorNode in drg)
        if mustResample:
            addResamplingNode(trace, drg, absorbing, aaa, stack, node, indexAssignments, i, hardBorder)
        elif node in aaa:
            # TODO temporary: once we put all uncollapsed AAA procs into AEKernels, this line won't be necessary
            addAAANode(drg, aaa, absorbing, node, indexAssignments, i)
        elif not isPrincipal and trace.pspAt(node).canAbsorb(trace, node, parentNode):
            addAbsorbingNode(drg, absorbing, aaa, node, indexAssignments, i)
        elif trace.pspAt(node).childrenCanAAA():
            addAAANode(drg, aaa, absorbing, node, indexAssignments, i)
        else:
            addResamplingNode(trace, drg, absorbing, aaa, stack, node, indexAssignments, i, hardBorder)
def doit(node):
    """Return the lookup source of ``node`` wrapped in an ``OrderedFrozenSet``.

    For a lookup node the set holds the first entry of
    ``trace.definiteParentsAt(node)``; for any other node it is empty.
    ``n`` and ``trace`` are free variables from the enclosing scope (not
    visible here).
    """
    members = []
    if n.isLookupNode(node):
        # The definite parents are the lookup source
        members.append(trace.definiteParentsAt(node)[0])
    return OrderedFrozenSet(members)