Пример #1
0
def computeRegenCounts(trace, drg, absorbing, aaa, border, brush, hardBorder):
    """Build the initial regen count for every node in ``drg``.

    Counts, in ``drg`` iteration order:
      * nodes in ``aaa`` or ``hardBorder`` start at 1 (for AAA nodes, more is
        added below); ``hardBorder`` nodes regenerate regardless of how many
        children they have;
      * other nodes start at their number of children, plus one extra if the
        node is in ``border``.

    When ``aaa`` is non-empty, ``maybeIncrementAAARegenCount`` is then offered
    every parent of each node in ``drg | absorbing``, every ESR parent of each
    output node in ``brush``, and the source node of each lookup node in
    ``brush``.

    Returns:
      An ``OrderedDict`` mapping node -> regen count.
    """
    counts = OrderedDict()
    for n in drg:
        if n in aaa or n in hardBorder:
            counts[n] = 1
        else:
            extra = 1 if n in border else 0
            counts[n] = len(trace.childrenAt(n)) + extra

    if not aaa:
        return counts

    for n in drg.union(absorbing):
        for parent in trace.parentsAt(n):
            maybeIncrementAAARegenCount(trace, counts, aaa, parent)

    for n in brush:
        if isOutputNode(n):
            sources = trace.esrParentsAt(n)
        elif isLookupNode(n):
            sources = (n.sourceNode,)
        else:
            continue
        for source in sources:
            maybeIncrementAAARegenCount(trace, counts, aaa, source)

    return counts
Пример #2
0
def extract(trace, node, scaffold, omegaDB, compute_gradient=False):
    """Tear down ``node`` within ``scaffold`` and return the weight removed.

    Stale AAA state for ``node`` is always extracted first. If ``node`` is a
    resampling node, its regen count is decremented; only when that count
    reaches zero is the node itself dismantled:
      * application nodes have their requests unevaluated (request nodes
        only) and their PSP unapplied;
      * lookup/constant nodes have their value cleared, and, when
        ``compute_gradient`` is set, forward their partial derivative in
        ``omegaDB`` to their (at most one) parent;
    after which the node's parents are extracted in turn.
    """
    total = 0
    total += maybeExtractStaleAAA(trace, node, scaffold, omegaDB,
                                  compute_gradient)
    if not scaffold.isResampling(node):
        return total

    trace.decRegenCountAt(scaffold, node)
    assert trace.regenCountAt(scaffold, node) >= 0
    if trace.regenCountAt(scaffold, node) != 0:
        # Regen count not yet exhausted; leave the node in place.
        return total

    if not isApplicationNode(node):
        trace.setValueAt(node, None)
        assert isLookupNode(node) or isConstantNode(node)
        assert len(trace.parentsAt(node)) <= 1
        if compute_gradient:
            for parent in trace.parentsAt(node):
                # d/dx is 1 for a lookup node, so the partial passes through.
                omegaDB.addPartial(parent, omegaDB.getPartial(node))
    else:
        if isRequestNode(node):
            total += unevalRequests(trace, node, scaffold, omegaDB,
                                    compute_gradient)
        total += unapplyPSP(trace, node, scaffold, omegaDB, compute_gradient)

    total += extractParents(trace, node, scaffold, omegaDB, compute_gradient)
    return total
Пример #3
0
def restore(trace, node, scaffold, omegaDB, gradients):
    """Re-attach ``node`` and return the accumulated weight.

    Constant nodes contribute 0. A lookup node regenerates its parents with
    ``shouldRestore=True``, is reconnected, and copies its source node's
    value. An output node restores its operator and then each operand, in
    order, before ``apply`` is called (again with ``shouldRestore=True``) on
    its request/output pair. Non-constant results go through
    ``ensure_python_float``.
    """
    if isConstantNode(node):
        return 0

    if not isLookupNode(node):
        # Anything that is neither constant nor lookup must be an output node.
        assert isOutputNode(node)
        total = restore(trace, node.operatorNode, scaffold, omegaDB, gradients)
        for operand in node.operandNodes:
            total += restore(trace, operand, scaffold, omegaDB, gradients)
        total += apply(trace, node.requestNode, node, scaffold, True,
                       omegaDB, gradients)
        return ensure_python_float(total)

    total = regenParents(trace, node, scaffold, True, omegaDB, gradients)
    trace.reconnectLookup(node)
    trace.setValueAt(node, trace.valueAt(node.sourceNode))
    return ensure_python_float(total)
Пример #4
0
 def getOutermostNonReferenceNode(self, node):
     """Follow reference-like nodes from ``node`` to the first real one.

     Constant nodes are returned as-is. Lookup nodes are followed to their
     source node. Output nodes whose PSP is an ``ESRRefOutputPSP`` are
     followed to their first ESR parent; tag outputs are followed to their
     third operand. Any other output node is returned.

     Raises:
       infer.MissingEsrParentError: if an ESR-reference output node has no
         ESR parents (e.g. when called on a torus, as for rejection sampling).
     """
     if isConstantNode(node):
         return node
     if isLookupNode(node):
         return self.getOutermostNonReferenceNode(node.sourceNode)

     assert isOutputNode(node)
     if isinstance(self.pspAt(node), ESRRefOutputPSP):
         if not self.esrParentsAt(node):
             # Could happen if this method is called on a torus, e.g. for
             # rejection sampling.
             raise infer.MissingEsrParentError()
         return self.getOutermostNonReferenceNode(self.esrParentsAt(node)[0])
     if isTagOutputPSP(self.pspAt(node)):
         return self.getOutermostNonReferenceNode(node.operandNodes[2])
     return node
Пример #5
0
def propagateConstraint(trace, node, value):
    """Push a newly constrained value through the nodes downstream of it.

    Each node reached is handled by kind:
      * lookup nodes take on ``value``;
      * request nodes are only permitted if their PSP is a ``NullRequestPSP``;
      * output nodes are only permitted if their PSP is deterministic, in
        which case they are re-simulated from their current arguments.
    The walk then recurses into every child. The value handed to children is
    the value this node now holds: for a re-simulated output node that is the
    fresh result, so a lookup of that output mirrors it rather than the
    original constrained value.

    Raises:
      VentureException: if a non-null request node or a random output node
        lies downstream of the constrained node.
    """
    if isLookupNode(node):
        trace.setValueAt(node, value)
    elif isRequestNode(node):
        if not isinstance(trace.pspAt(node), NullRequestPSP):
            raise VentureException("evaluation", "Cannot make requests "
                "downstream of a node that gets constrained during regen",
                address=node.address)
    else:
        # TODO there may be more cases to ban here.
        # e.g. certain kinds of deterministic coupling through mutation.
        assert isOutputNode(node)
        psp = trace.pspAt(node)
        if psp.isRandom():
            raise VentureException("evaluation", "Cannot make random choices "
                "downstream of a node that gets constrained during regen",
                address=node.address)
        # TODO Is it necessary to unincorporate and incorporate here?  If
        # not, why not?
        # Rebind so children (notably lookups of this node) see the
        # recomputed value, not the upstream constraint.
        value = psp.simulate(trace.argsAt(node))
        trace.setValueAt(node, value)
    for child in trace.childrenAt(node):
        propagateConstraint(trace, child, value)
Пример #6
0
def regen(trace, node, scaffold, shouldRestore, omegaDB, gradients):
    """Regenerate ``node`` within ``scaffold`` and return the weight gained.

    A resampling node is only recomputed while its regen count is still zero:
    its parents are regenerated, then a lookup node propagates its lookup,
    while any other node applies its PSP (and, for request nodes, evaluates
    the resulting requests). Every visit to a resampling node increments its
    regen count. Stale AAA state is regenerated afterwards in all cases.
    """
    weight = 0
    if scaffold.isResampling(node):
        if trace.regenCountAt(scaffold, node) == 0:
            weight += regenParents(trace, node, scaffold,
                                   shouldRestore, omegaDB, gradients)
            if not isLookupNode(node):
                weight += applyPSP(trace, node, scaffold,
                                   shouldRestore, omegaDB, gradients)
                if isRequestNode(node):
                    weight += evalRequests(trace, node, scaffold,
                                           shouldRestore, omegaDB, gradients)
            else:
                propagateLookup(trace, node)
        trace.incRegenCountAt(scaffold, node)

    weight += maybeRegenStaleAAA(trace, node, scaffold,
                                 shouldRestore, omegaDB, gradients)
    return ensure_python_float(weight)
Пример #7
0
def updateValueAtNode(trace, scaffold, node, updatedNodes):
    """Bring ``node``'s value up to date, updating what it depends on first.

    Only resampling nodes of ``scaffold`` not already in ``updatedNodes`` are
    touched; each one handled is recorded in ``updatedNodes``.
      * Lookup nodes first update their source node, then copy its value.
      * Output nodes whose value is an ``SPRef``, or whose PSP reports
        ``childrenCanAAA()``, are left as-is. Otherwise every parent whose
        change the PSP cannot absorb is updated recursively, and if there was
        at least one such parent, ``update(trace, node)`` is called.
    """
    # Strong assumption! Only consider resampling nodes in the scaffold.
    if node in updatedNodes or not scaffold.isResampling(node):
        return

    if isLookupNode(node):
        source = node.sourceNode
        updateValueAtNode(trace, scaffold, source, updatedNodes)
        trace.setValueAt(node, trace.valueAt(source))
    elif isOutputNode(node):
        psp = trace.pspAt(node)
        # Assume SPRef and AAA nodes are always updated.
        alwaysFresh = (isinstance(trace.valueAt(node), SPRef)
                       or psp.childrenCanAAA())
        if not alwaysFresh:
            stale = False
            for parent in trace.parentsAt(node):
                if psp.canAbsorb(trace, node, parent):
                    continue
                updateValueAtNode(trace, scaffold, parent, updatedNodes)
                stale = True
            if stale:
                update(trace, node)
    updatedNodes.add(node)
Пример #8
0
def unevalFamily(trace, node, scaffold, omegaDB, compute_gradient=False):
    """Tear down the family rooted at ``node`` and return the weight removed.

    Constant nodes are left alone. A lookup node (which must have exactly one
    parent) optionally forwards its partial derivative in ``omegaDB`` to that
    parent, is disconnected, has its value cleared, and then has its parents
    extracted. An output node is unapplied first, then its operands are
    unevaluated from last to first, and finally its operator.
    """
    weight = 0
    if isConstantNode(node):
        return weight

    if isLookupNode(node):
        assert len(trace.parentsAt(node)) == 1
        if compute_gradient:
            for parent in trace.parentsAt(node):
                # d/dx is 1 for a lookup node, so the partial passes through.
                omegaDB.addPartial(parent, omegaDB.getPartial(node))
        trace.disconnectLookup(node)
        trace.setValueAt(node, None)
        weight += extractParents(trace, node, scaffold, omegaDB,
                                 compute_gradient)
        return weight

    assert isOutputNode(node)
    weight += unapply(trace, node, scaffold, omegaDB, compute_gradient)
    for operand in reversed(node.operandNodes):
        weight += unevalFamily(trace, operand, scaffold, omegaDB,
                               compute_gradient)
    weight += unevalFamily(trace, node.operatorNode, scaffold, omegaDB,
                           compute_gradient)
    return weight
Пример #9
0
def extendCandidateScaffold(trace, pnodes, drg, absorbing, aaa,
                            indexAssignments, i, hardBorder):
    """Grow a candidate scaffold outward from the principal nodes ``pnodes``.

    Nodes are taken from a LIFO work list seeded with ``pnodes`` (marked
    principal, with no parent). The list itself is passed to
    ``addResamplingNode``, presumably so it can enqueue further nodes.
    Each node is classified, in priority order, as:
      1. resampling, if already in ``drg`` but not ``aaa``, or a lookup node,
         or its operator node is in ``drg``;
      2. AAA, if already in ``aaa``;
      3. absorbing, if not principal and its PSP can absorb a change from the
         parent it was reached through;
      4. AAA, if its PSP reports ``childrenCanAAA()``;
      5. resampling otherwise.
    """
    worklist = [(pnode, True, None) for pnode in pnodes]

    while worklist:
        node, principal, parent = worklist.pop()
        mustResample = ((node in drg and node not in aaa)
                        or isLookupNode(node)
                        or node.operatorNode in drg)
        if mustResample:
            addResamplingNode(trace, drg, absorbing, aaa, worklist, node,
                              indexAssignments, i, hardBorder)
        # TODO temporary: once we put all uncollapsed AAA procs into
        # AEKernels, this line won't be necessary
        elif node in aaa:
            addAAANode(drg, aaa, absorbing, node, indexAssignments, i)
        elif not principal and trace.pspAt(node).canAbsorb(trace, node,
                                                           parent):
            addAbsorbingNode(drg, absorbing, aaa, node, indexAssignments, i)
        elif trace.pspAt(node).childrenCanAAA():
            addAAANode(drg, aaa, absorbing, node, indexAssignments, i)
        else:
            addResamplingNode(trace, drg, absorbing, aaa, worklist, node,
                              indexAssignments, i, hardBorder)
Пример #10
0
 def doit(node):
     """Return the definite-parent set of ``node`` for this traversal.

     Lookup nodes yield a singleton set holding their first definite parent
     (the lookup source); every other node yields an empty set.
     """
     if not n.isLookupNode(node):
         return OrderedFrozenSet([])
     # The definite parents are the lookup source.
     source = trace.definiteParentsAt(node)[0]
     return OrderedFrozenSet([source])