예제 #1
0
 def doit(node):
     """Return the request node feeding ``node`` as a one-element set.

     Only output nodes have one: for those, the last entry of
     ``trace.definiteParentsAt(node)`` is wrapped in an OrderedFrozenSet.
     Every other node yields an empty OrderedFrozenSet.

     ``n`` and ``trace`` are free variables bound by the enclosing scope
     (not visible here).
     """
     if not n.isOutputNode(node):
         return OrderedFrozenSet([])
     # The last definite parent is the request node.
     requestParent = trace.definiteParentsAt(node)[-1]
     return OrderedFrozenSet([requestParent])
예제 #2
0
def computeRegenCounts(trace, drg, absorbing, aaa, border, brush, hardBorder):
    """Build an OrderedDict of regen counts for every node in ``drg``.

    Base count per DRG node, checked in this priority order:
      * node in ``aaa``        -> 1 (AAA increments are added below)
      * node in ``hardBorder`` -> 1 (regenerates regardless of children)
      * node in ``border``     -> number of children + 1
      * otherwise              -> number of children

    When ``aaa`` is non-empty, ``maybeIncrementAAARegenCount`` is then
    called for every parent of each node in ``drg | absorbing``, for every
    ESR parent of each output node in ``brush``, and for the source node of
    each lookup node in ``brush``.
    """
    def baseCount(node):
        if node in aaa or node in hardBorder:
            return 1
        childCount = len(trace.childrenAt(node))
        return childCount + 1 if node in border else childCount

    regenCounts = OrderedDict((node, baseCount(node)) for node in drg)
    if not aaa:
        return regenCounts

    def bump(candidate):
        maybeIncrementAAARegenCount(trace, regenCounts, aaa, candidate)

    for node in drg.union(absorbing):
        for parent in trace.parentsAt(node):
            bump(parent)

    for node in brush:
        if isOutputNode(node):
            for esrParent in trace.esrParentsAt(node):
                bump(esrParent)
        elif isLookupNode(node):
            bump(node.sourceNode)

    return regenCounts
예제 #3
0
def disableFamily(trace, node, disableCounts, disabledRequests, brush):
    """Add the family rooted at ``node`` to ``brush``.

    A node already in ``brush`` is skipped, which also stops recursion.
    For an output node, its request node is added to ``brush`` and passed
    to ``disableRequests``; then the operator node and each operand node
    are processed recursively, operator first.
    """
    if node in brush:
        return
    brush.add(node)
    if not isOutputNode(node):
        return

    requestNode = node.requestNode
    brush.add(requestNode)
    disableRequests(trace, requestNode, disableCounts, disabledRequests,
                    brush)

    disableFamily(trace, node.operatorNode, disableCounts, disabledRequests,
                  brush)
    for operand in node.operandNodes:
        disableFamily(trace, operand, disableCounts, disabledRequests, brush)
예제 #4
0
def loadKernels(trace, drg, aaa, useDeltaKernels, deltaKernelArgs):
    """Build the local-kernel map used during regeneration.

    Every node in ``aaa`` is mapped to its PSP's AAA local kernel. When
    ``useDeltaKernels`` is true, each output node in ``drg - aaa`` is also
    mapped to its PSP's delta kernel (built from ``deltaKernelArgs``), but
    only if the PSP has one and neither the node's operator nor any of its
    operands is itself in ``drg``.

    Returns:
        OrderedDict mapping node -> local kernel, AAA entries first.
    """
    lkernels = OrderedDict(
        (node, trace.pspAt(node).getAAALKernel()) for node in aaa)
    if not useDeltaKernels:
        return lkernels

    for node in drg - aaa:
        if not isOutputNode(node):
            continue
        if node.operatorNode in drg:
            continue
        # If you're wondering about this fallback clause, the rationale
        # is in the "joint-delta-kernels" footnote of doc/on-latents.md.
        # any() is required: a bare `continue` inside a loop over the
        # operands would only advance to the next operand, never skip
        # this node.
        if any(o in drg for o in node.operandNodes):
            continue
        psp = trace.pspAt(node)
        if psp.hasDeltaKernel():
            lkernels[node] = psp.getDeltaKernel(deltaKernelArgs)
    return lkernels
예제 #5
0
def restore(trace, node, scaffold, omegaDB, gradients):
  """Re-evaluate the family rooted at ``node`` and return its weight.

  Constant nodes contribute 0. Output nodes restore the operator family,
  then each operand family, then call ``apply`` on the request/output pair.
  Lookup nodes call ``regenParents``, are reconnected, and take their
  source node's value. ``True`` is passed as the restore flag to
  ``regenParents``/``apply`` (presumably so stored values in ``omegaDB``
  are replayed -- confirm against those helpers).

  Non-constant results go through ``ensure_python_float``.
  """
  if isConstantNode(node):
    return 0

  if not isLookupNode(node):
    # node is output node
    assert isOutputNode(node)
    total = restore(trace, node.operatorNode, scaffold, omegaDB, gradients)
    for operand in node.operandNodes:
      total += restore(trace, operand, scaffold, omegaDB, gradients)
    total += apply(trace, node.requestNode, node, scaffold, True, omegaDB,
                   gradients)
    return ensure_python_float(total)

  total = regenParents(trace, node, scaffold, True, omegaDB, gradients)
  trace.reconnectLookup(node)
  sourceValue = trace.valueAt(node.sourceNode)
  trace.setValueAt(node, sourceValue)
  return ensure_python_float(total)
예제 #6
0
 def getOutermostNonReferenceNode(self, node):
     """Follow reference nodes from ``node`` to the first non-reference node.

     Constant nodes and ordinary output nodes are returned unchanged.
     Lookup nodes are followed to their source node, ESRRefOutputPSP
     outputs to their first ESR parent, and tag outputs
     (``isTagOutputPSP``) to their third operand; the walk recurses until
     it reaches a node of neither kind.

     Raises:
         infer.MissingEsrParentError: an ESRRefOutputPSP node has no ESR
             parents.
     """
     if isConstantNode(node):
         return node
     if isLookupNode(node):
         return self.getOutermostNonReferenceNode(node.sourceNode)

     assert isOutputNode(node)
     psp = self.pspAt(node)
     if isinstance(psp, ESRRefOutputPSP):
         esrParents = self.esrParentsAt(node)
         if not esrParents:
             # Could happen if this method is called on a torus, e.g. for rejection sampling
             raise infer.MissingEsrParentError()
         target = esrParents[0]
     elif isTagOutputPSP(psp):
         target = node.operandNodes[2]
     else:
         return node
     return self.getOutermostNonReferenceNode(target)
예제 #7
0
 def freeze(self, id):
     """Pin the family registered under ``id`` to its current value.

     Constant nodes are left untouched. For an output node, the current
     value is captured, the family is unevaluated with a fresh Scaffold and
     OmegaDB, and the node is marked ``isFrozen``. The captured value is then
     written back, and its request, operand and operator links are cleared.
     """
     assert id in self.families
     node = self.families[id]
     if isConstantNode(node):
         return  # All set: constants are already fixed.

     assert isOutputNode(node)
     frozenValue = self.valueAt(node)
     unevalFamily(self, node, Scaffold(), OmegaDB())
     # XXX Ideally this node's identity would be replaced by a constant
     # node, but there is no clean way to do that, so it is faked by
     # marking it frozen and dropping its components.
     node.isFrozen = True
     self.setValueAt(node, frozenValue)
     node.requestNode = None
     node.operandNodes = None
     node.operatorNode = None
예제 #8
0
def propagateConstraint(trace, node, value):
  """Push a newly constrained value downstream from ``node``.

  ``value`` is the new value of ``node``'s parent on the propagation path.
  A lookup node takes ``value``. A request node is only allowed with a
  NullRequestPSP. A deterministic output node is re-simulated from its
  current arguments. Each child is then visited with the value of the node
  just processed, so a lookup downstream of a recomputed output node sees
  that output's new value rather than the original constrained value.

  Raises:
    VentureException: a downstream request node has a PSP other than
      NullRequestPSP, or a downstream output node's PSP is random.
  """
  if isLookupNode(node):
    trace.setValueAt(node, value)
    downstreamValue = value
  elif isRequestNode(node):
    if not isinstance(trace.pspAt(node), NullRequestPSP):
      raise VentureException("evaluation", "Cannot make requests "
        "downstream of a node that gets constrained during regen",
        address=node.address)
    # The children of a request node are output nodes, which recompute
    # from their own arguments, so the incoming value is passed through.
    downstreamValue = value
  else:
    # TODO there may be more cases to ban here.
    # e.g. certain kinds of deterministic coupling through mutation.
    assert isOutputNode(node)
    psp = trace.pspAt(node)
    if psp.isRandom():
      raise VentureException("evaluation", "Cannot make random choices "
        "downstream of a node that gets constrained during regen",
        address=node.address)
    # TODO Is it necessary to unincorporate and incorporate here?  If
    # not, why not?
    downstreamValue = psp.simulate(trace.argsAt(node))
    trace.setValueAt(node, downstreamValue)
  for child in trace.childrenAt(node):
    propagateConstraint(trace, child, downstreamValue)
예제 #9
0
def updateValueAtNode(trace, scaffold, node, updatedNodes):
    """Bring ``node``'s value up to date, recursing into stale parents first.

    Nodes already in ``updatedNodes``, and nodes the scaffold is not
    resampling, are left alone and not recorded. A lookup node updates its
    source, then copies the source's value. An output node whose value is
    not an SPRef and whose PSP does not report ``childrenCanAAA()`` updates
    every parent its PSP cannot absorb. If there was at least one such
    parent, it calls ``update`` on the node. The node is then recorded in
    ``updatedNodes``.
    """
    # Strong assumption! Only consider resampling nodes in the scaffold.
    if node in updatedNodes or not scaffold.isResampling(node):
        return

    if isLookupNode(node):
        source = node.sourceNode
        updateValueAtNode(trace, scaffold, source, updatedNodes)
        trace.setValueAt(node, trace.valueAt(source))
    elif isOutputNode(node):
        psp = trace.pspAt(node)
        # Assume SPRef and AAA nodes are always updated.
        alwaysCurrent = (isinstance(trace.valueAt(node), SPRef)
                         or psp.childrenCanAAA())
        if not alwaysCurrent:
            needsUpdate = False
            for parent in trace.parentsAt(node):
                if psp.canAbsorb(trace, node, parent):
                    continue
                updateValueAtNode(trace, scaffold, parent, updatedNodes)
                needsUpdate = True
            if needsUpdate:
                update(trace, node)

    updatedNodes.add(node)
예제 #10
0
def unevalFamily(trace, node, scaffold, omegaDB, compute_gradient=False):
    """Unevaluate the family rooted at ``node`` and return the total weight.

    * Constant nodes contribute nothing (weight 0).
    * Lookup nodes (exactly one parent): when ``compute_gradient`` is set,
      the node's partial from ``omegaDB`` is added to its parent's partial.
      The lookup is then disconnected, its value set to None, and
      ``extractParents`` is applied.
    * Output nodes: ``unapply`` runs first, then the operand families are
      unevaluated from last to first, and finally the operator family.

    Returns:
        Sum of the weights returned by the helper calls above.
    """
    weight = 0
    if isConstantNode(node):
        return weight

    if isLookupNode(node):
        assert len(trace.parentsAt(node)) == 1
        if compute_gradient:
            for parent in trace.parentsAt(node):
                # d/dx is 1 for a lookup node, so the partial passes through.
                omegaDB.addPartial(parent, omegaDB.getPartial(node))
        trace.disconnectLookup(node)
        trace.setValueAt(node, None)
        weight += extractParents(trace, node, scaffold, omegaDB,
                                 compute_gradient)
        return weight

    assert isOutputNode(node)
    weight += unapply(trace, node, scaffold, omegaDB, compute_gradient)
    # Tear down in the reverse of evaluation order.
    for operand in reversed(node.operandNodes):
        weight += unevalFamily(trace, operand, scaffold, omegaDB,
                               compute_gradient)
    weight += unevalFamily(trace, node.operatorNode, scaffold, omegaDB,
                           compute_gradient)
    return weight
예제 #11
0
    def addRandomChoicesInExtent(self, node, scope, block, pnodes):
        """Add the random choices reachable from ``node`` to ``pnodes``.

        Only output nodes are explored. The node itself and its request
        node are added when their PSP is random and they are not in
        ``self.ccs``. The walk then recurses, in order, into:
          1. the families of the request's ESRs,
          2. the operator node,
          3. each operand node, except:
             * operand 2 of a tag output is skipped when its (normalized)
               scope equals ``scope`` but its block differs from ``block``;
             * operand 1 of a tag-exclude output is skipped when its
               (normalized) excluded scope equals ``scope``.
        """
        if not isOutputNode(node):
            return

        def visit(child):
            self.addRandomChoicesInExtent(child, scope, block, pnodes)

        def collectIfRandom(candidate):
            if self.pspAt(candidate).isRandom() and candidate not in self.ccs:
                pnodes.add(candidate)

        collectIfRandom(node)
        requestNode = node.requestNode
        collectIfRandom(requestNode)

        for esr in self.valueAt(requestNode).esrs:
            visit(self.spFamilyAt(requestNode, esr.id))

        visit(node.operatorNode)

        for index, operandNode in enumerate(node.operandNodes):
            if index == 2 and isTagOutputPSP(self.pspAt(node)):
                tagScope, tagBlock, _ = [
                    self.valueAt(operand) for operand in node.operandNodes
                ]
                tagScope, tagBlock = self._normalizeEvaluatedScopeAndBlock(
                    tagScope, tagBlock)
                # Skip only a re-tag of this same scope with another block.
                shouldVisit = scope != tagScope or block == tagBlock
            elif index == 1 and isTagExcludeOutputPSP(self.pspAt(node)):
                excludedScope, _ = [
                    self.valueAt(operand) for operand in node.operandNodes
                ]
                excludedScope = self._normalizeEvaluatedScope(excludedScope)
                shouldVisit = scope != excludedScope
            else:
                shouldVisit = True
            if shouldVisit:
                visit(operandNode)