Example #1
0
    def __init__(self):
        """Initialize an empty reduced-traces engine.

        Creates the directive tables, the random-choice databases used to
        pick proposals, the global environment, the log-score and proposal
        accumulators, and placeholder state that the MH proposal step keeps
        between proposing a change and keeping/restoring it.
        """
        self.engine_type = 'reduced traces'
        self.assumes = {}  # id -> evalnode
        self.observes = {}  # id -> evalnode
        self.predicts = {}  # id -> evalnode
        self.directives = []  # directive kind ('assume'/'observe'/'predict') per id

        # Proposal databases keyed by hash; presumably weight-1 entries go in
        # db and other non-zero weights in weighted_db (as in add_to_db).
        self.db = RandomChoiceDict()
        self.weighted_db = WeightedRandomChoiceDict()
        self.choices = {}  # hash -> evalnode
        self.xrps = {}  # hash -> (xrp, set of application nodes)

        # Global environment that assumes are bound into.
        self.env = EnvironmentNode()

        # Log score of the trace and per-proposal log accumulators.
        self.p = 0
        self.uneval_p = 0
        self.eval_p = 0
        self.new_to_old_q = 0
        self.old_to_new_q = 0

        self.debug = False

        # necessary because of the new XRP interface requiring some state kept while doing inference
        self.application_reflip = False
        # Placeholder node; note it is built with the partially initialized engine.
        self.reflip_node = ReducedEvalNode(self, self.env, VarExpression(''))
        self.nodes = []
        self.old_vals = [Value()]
        self.new_vals = [Value()]
        self.old_val = Value()
        self.new_val = Value()
        self.reflip_xrp = XRP()

        # Proposal statistics; per-hash details only when mhstats_details is set.
        self.mhstats_details = False
        self.mhstats = {}
        self.made_proposals = 0
        self.accepted_proposals = 0

        # NOTE(review): presumably a random id for this engine; randbelow()
        # semantics with no argument come from rrandom — confirm.
        self.hashval = rrandom.random.randbelow()
        return
Example #2
0
    def __init__(self):
        """Initialize an empty traces engine.

        Creates the directive tables, the random-choice databases, the
        global environment, the log-score and proposal accumulators, and
        placeholder state kept between proposing a change and restoring it.
        """
        self.engine_type = "traces"

        self.assumes = {}  # id -> evalnode
        self.observes = {}  # id -> evalnode
        self.predicts = {}  # id -> evalnode
        self.directives = []  # directive kind per id

        # Proposal databases keyed by hash.
        self.db = RandomChoiceDict()
        self.weighted_db = WeightedRandomChoiceDict()
        self.choices = {}  # hash -> evalnode
        self.xrps = {}  # hash -> (xrp, set of application nodes)

        # Global environment that assumes are bound into.
        self.env = EnvironmentNode()

        # Log score of the trace and proposal log-probability accumulators.
        self.p = 0
        self.old_to_new_q = 0
        self.new_to_old_q = 0

        self.eval_xrps = []  # (xrp, args, val)
        self.uneval_xrps = []  # (xrp, args, val)

        self.debug = False

        # Stuff for restoring
        self.application_reflip = False
        # Placeholder node; note it is built with the partially initialized engine.
        self.reflip_node = EvalNode(self, self.env, VarExpression(""))
        self.nodes = []
        self.old_vals = []
        self.new_vals = []
        self.old_val = Value()
        self.new_val = Value()
        self.reflip_xrp = XRP()

        # Proposal statistics; per-hash details only when mhstats_details is set.
        self.made_proposals = 0
        self.accepted_proposals = 0

        self.mhstats_details = False
        self.mhstats = {}
        return
  def __init__(self):
    """Initialize an empty reduced-traces engine.

    Creates the directive tables, the random-choice databases used to pick
    proposals, the global environment, the log-score and proposal
    accumulators, and placeholder state that the MH proposal step keeps
    between proposing a change and keeping/restoring it.
    """
    self.engine_type = 'reduced traces'
    self.assumes = {} # id -> evalnode
    self.observes = {} # id -> evalnode
    self.predicts = {} # id -> evalnode
    self.directives = [] # directive kind ('assume'/'observe'/'predict') per id

    # Proposal databases keyed by hash; presumably weight-1 entries go in db
    # and other non-zero weights in weighted_db (as in add_to_db).
    self.db = RandomChoiceDict() 
    self.weighted_db = WeightedRandomChoiceDict() 
    self.choices = {} # hash -> evalnode
    self.xrps = {} # hash -> (xrp, set of application nodes)

    # Global environment that assumes are bound into.
    self.env = EnvironmentNode()

    # Log score of the trace and per-proposal log accumulators.
    self.p = 0
    self.uneval_p = 0
    self.eval_p = 0
    self.new_to_old_q = 0
    self.old_to_new_q = 0

    self.debug = False

    # necessary because of the new XRP interface requiring some state kept while doing inference
    self.application_reflip = False
    # Placeholder node; note it is built with the partially initialized engine.
    self.reflip_node = ReducedEvalNode(self, self.env, VarExpression(''))
    self.nodes = []
    self.old_vals = [Value()]
    self.new_vals = [Value()]
    self.old_val = Value() 
    self.new_val = Value() 
    self.reflip_xrp = XRP()

    # Proposal statistics; per-hash details only when mhstats_details is set.
    self.mhstats_details = False
    self.mhstats = {}
    self.made_proposals = 0
    self.accepted_proposals = 0

    # NOTE(review): presumably a random id for this engine; randbelow()
    # semantics with no argument come from rrandom — confirm.
    self.hashval = rrandom.random.randbelow()
    return
Example #4
0
class ReducedTraces(Engine):
    def __init__(self):
        self.engine_type = 'reduced traces'
        self.assumes = {}  # id -> evalnode
        self.observes = {}  # id -> evalnode
        self.predicts = {}  # id -> evalnode
        self.directives = []

        self.db = RandomChoiceDict()
        self.weighted_db = WeightedRandomChoiceDict()
        self.choices = {}  # hash -> evalnode
        self.xrps = {}  # hash -> (xrp, set of application nodes)

        self.env = EnvironmentNode()

        self.p = 0
        self.uneval_p = 0
        self.eval_p = 0
        self.new_to_old_q = 0
        self.old_to_new_q = 0

        self.debug = False

        # necessary because of the new XRP interface requiring some state kept while doing inference
        self.application_reflip = False
        self.reflip_node = ReducedEvalNode(self, self.env, VarExpression(''))
        self.nodes = []
        self.old_vals = [Value()]
        self.new_vals = [Value()]
        self.old_val = Value()
        self.new_val = Value()
        self.reflip_xrp = XRP()

        self.mhstats_details = False
        self.mhstats = {}
        self.made_proposals = 0
        self.accepted_proposals = 0

        self.hashval = rrandom.random.randbelow()
        return

    def mhstats_on(self):
        self.mhstats_details = True

    def mhstats_off(self):
        self.mhstats_details = False

    def mhstats_aggregated(self):
        d = {}
        d['made-proposals'] = self.made_proposals
        d['accepted-proposals'] = self.accepted_proposals
        return d

    def mhstats_detailed(self):
        return self.mhstats

    def get_log_score(self, id):
        if id == -1:
            return self.p
        else:
            return self.observes[id].p

    def weight(self):
        return self.db.weight() + self.weighted_db.weight()

    def random_choices(self):
        return self.db.__len__() + self.weighted_db.__len__()

    def assume(self, name, expr, id=-1):
        evalnode = ReducedEvalNode(self, self.env, expr)
        self.env.add_assume(name, evalnode)
        evalnode.add_assume(name, self.env)

        if id != -1:
            self.assumes[id] = evalnode
            assert id == len(self.directives)
            self.directives.append('assume')
        val = evalnode.evaluate()
        return val

    def predict(self, expr, id):
        evalnode = ReducedEvalNode(self, self.env, expr)

        assert id == len(self.directives)
        self.directives.append('predict')
        self.predicts[id] = evalnode

        evalnode.predict = True
        val = evalnode.evaluate()
        return val

    def observe(self, expr, obs_val, id):
        evalnode = ReducedEvalNode(self, self.env, expr)

        assert id == len(self.directives)
        self.directives.append('observe')
        self.observes[id] = evalnode

        evalnode.observed = True
        evalnode.observe_val = obs_val

        val = evalnode.evaluate()
        return val

    def forget(self, id):
        if id in self.observes:
            d = self.observes
        elif id in self.predicts:
            d = self.predicts
        else:
            raise RException("Can only forget predicts and observes")
        evalnode = d[id]
        evalnode.unevaluate()
        #del d[id]
        return

    def report_value(self, id):
        node = self.get_directive_node(id)
        if not node.active:
            raise RException("Error.  Perhaps this directive was forgotten?")
        val = node.val
        return val

    def get_directive_node(self, id):
        if self.directives[id] == 'assume':
            node = self.assumes[id]
        elif self.directives[id] == 'observe':
            node = self.observes[id]
        else:
            assert self.directives[id] == 'predict'
            node = self.predicts[id]
        return node

    def add_accepted_proposal(self, hashval):
        if self.mhstats_details:
            if hashval in self.mhstats:
                self.mhstats[hashval]['accepted-proposals'] += 1
            else:
                self.mhstats[hashval] = {}
                self.mhstats[hashval]['accepted-proposals'] = 1
                self.mhstats[hashval]['made-proposals'] = 0
        self.accepted_proposals += 1

    def add_made_proposal(self, hashval):
        if self.mhstats_details:
            if hashval in self.mhstats:
                self.mhstats[hashval]['made-proposals'] += 1
            else:
                self.mhstats[hashval] = {}
                self.mhstats[hashval]['made-proposals'] = 1
                self.mhstats[hashval]['accepted-proposals'] = 0
        self.made_proposals += 1

    def reflip(self, hashval):
        if self.debug:
            print self

        if hashval in self.choices:
            self.application_reflip = True
            self.reflip_node = self.choices[hashval]
            if not self.reflip_node.random_xrp_apply:
                raise RException(
                    "Reflipping something which isn't a random xrp application"
                )
            if self.reflip_node.val is None:
                raise RException(
                    "Reflipping something which previously had value None")
        else:
            self.application_reflip = False  # internal reflip
            (self.reflip_xrp, nodes) = self.xrps[hashval]
            self.nodes = list_nodes(nodes)

        self.eval_p = 0
        self.uneval_p = 0

        old_p = self.p
        self.old_to_new_q = -math.log(self.weight())

        if self.application_reflip:
            self.old_val = self.reflip_node.val
            self.new_val = self.reflip_node.reflip()
        else:
            # TODO: this is copied from traces.  is it correct?
            args_list = []
            self.old_vals = []
            for node in self.nodes:
                args_list.append(node.args)
                self.old_vals.append(node.val)
            self.old_to_new_q += math.log(self.reflip_xrp.state_weight())
            old_p += self.reflip_xrp.theta_prob()

            self.new_vals, q_forwards, q_back = self.reflip_xrp.theta_mh_prop(
                args_list, self.old_vals)
            self.old_to_new_q += q_forwards
            self.new_to_old_q += q_back

            for i in range(len(self.nodes)):
                node = self.nodes[i]
                val = self.new_vals[i]
                node.set_val(val)
                node.propagate_up(False)

        new_p = self.p
        self.new_to_old_q = -math.log(self.weight())
        self.old_to_new_q += self.eval_p
        self.new_to_old_q += self.uneval_p

        if not self.application_reflip:
            new_p += self.reflip_xrp.theta_prob()
            self.new_to_old_q += math.log(self.reflip_xrp.state_weight())

        if self.debug:
            if self.application_reflip:
                print "\nCHANGING VAL OF ", self.reflip_node, "\n  FROM  :  ", self.old_val, "\n  TO   :  ", self.new_val, "\n"
                if (self.old_val.__eq__(self.new_val)).bool:
                    print "SAME VAL"
            else:
                print "TRANSITIONING STATE OF ", self.reflip_xrp

            print "new db", self
            print "\nq(old -> new) : ", math.exp(self.old_to_new_q)
            print "q(new -> old) : ", math.exp(self.new_to_old_q)
            print "p(old) : ", math.exp(old_p)
            print "p(new) : ", math.exp(new_p)
            print 'transition prob : ', math.exp(new_p + self.new_to_old_q -
                                                 old_p -
                                                 self.old_to_new_q), "\n"
            print "\n-----------------------------------------\n"

        return old_p, self.old_to_new_q, new_p, self.new_to_old_q

    def restore(self):
        if self.application_reflip:
            self.reflip_node.restore(self.old_val)
        else:
            for i in range(len(self.nodes)):
                node = self.nodes[i]
                node.set_val(self.old_vals[i])
                node.propagate_up(True)
            self.reflip_xrp.theta_mh_restore()
        if self.debug:
            print 'restore'

    def keep(self):
        if self.application_reflip:
            self.reflip_node.restore(
                self.new_val
            )  # NOTE: Is necessary for correctness, as we must forget old branch
        else:
            for i in range(len(self.nodes)):
                node = self.nodes[i]
                node.restore(self.new_vals[i])
            self.reflip_xrp.theta_mh_keep()

    def add_for_transition(self, xrp, evalnode):
        hashval = xrp.__hash__()
        if hashval not in self.xrps:
            self.xrps[hashval] = (xrp, {})
        (xrp, evalnodes) = self.xrps[hashval]
        evalnodes[evalnode] = True

        weight = xrp.state_weight()
        self.delete_from_db(xrp.__hash__())
        self.add_to_db(xrp.__hash__(), weight)

    # Add an XRP application node to the db
    def add_xrp(self, xrp, args, evalnode):
        weight = xrp.weight(args)
        evalnode.setargs(args)
        try:
            self.new_to_old_q += math.log(weight)
        except:
            pass  # This is only necessary if we're reflipping

        if self.weighted_db.__contains__(
                evalnode.hashval) or self.db.__contains__(
                    evalnode.hashval) or (evalnode.hashval in self.choices):
            raise RException("DB already had this evalnode")
        self.choices[evalnode.hashval] = evalnode
        self.add_to_db(evalnode.hashval, weight)

    def add_to_db(self, hashval, weight):
        if weight == 0:
            return
        elif weight == 1:
            self.db.__setitem__(hashval, True)
        else:
            self.weighted_db.__setitem__(hashval, True, weight)

    def remove_for_transition(self, xrp, evalnode):
        hashval = xrp.__hash__()
        (xrp, evalnodes) = self.xrps[hashval]
        del evalnodes[evalnode]
        if len(evalnodes) == 0:
            del self.xrps[hashval]
            self.delete_from_db(xrp.__hash__())

    def remove_xrp(self, evalnode):
        xrp = evalnode.xrp
        try:
            self.old_to_new_q += math.log(xrp.weight(evalnode.args))
        except:
            pass  # This fails when restoring/keeping, for example

        if evalnode.hashval not in self.choices:
            raise RException("Choices did not already have this evalnode")
        else:
            del self.choices[evalnode.hashval]
            self.delete_from_db(evalnode.hashval)

    def delete_from_db(self, hashval):
        if self.db.__contains__(hashval):
            self.db.__delitem__(hashval)
        elif self.weighted_db.__contains__(hashval):
            self.weighted_db.__delitem__(hashval)

    def randomKey(self):
        if rrandom.random.random() * self.weight() > self.db.weight():
            return self.weighted_db.randomKey()
        else:
            return self.db.randomKey()

    def infer(self):
        try:
            hashval = self.randomKey()
        except:
            raise RException("Program has no randomness!")

        old_p, old_to_new_q, new_p, new_to_old_q = self.reflip(hashval)
        p = rrandom.random.random()
        if new_p + new_to_old_q - old_p - old_to_new_q < math.log(p):
            self.restore()
        else:
            self.keep()
            self.add_accepted_proposal(hashval)
        self.add_made_proposal(hashval)

    def reset(self):
        self.__init__()

    def __str__(self):
        string = "EvalNodeTree:"
        for evalnode in self.assumes.values():
            string += evalnode.str_helper()
        for evalnode in self.observes.values():
            string += evalnode.str_helper()
        for evalnode in self.predicts.values():
            string += evalnode.str_helper()
        return string
class ReducedTraces(Engine):
  def __init__(self):
    self.engine_type = 'reduced traces'
    self.assumes = {} # id -> evalnode
    self.observes = {} # id -> evalnode
    self.predicts = {} # id -> evalnode
    self.directives = []

    self.db = RandomChoiceDict() 
    self.weighted_db = WeightedRandomChoiceDict() 
    self.choices = {} # hash -> evalnode
    self.xrps = {} # hash -> (xrp, set of application nodes)

    self.env = EnvironmentNode()

    self.p = 0
    self.uneval_p = 0
    self.eval_p = 0
    self.new_to_old_q = 0
    self.old_to_new_q = 0

    self.debug = False

    # necessary because of the new XRP interface requiring some state kept while doing inference
    self.application_reflip = False
    self.reflip_node = ReducedEvalNode(self, self.env, VarExpression(''))
    self.nodes = []
    self.old_vals = [Value()]
    self.new_vals = [Value()]
    self.old_val = Value() 
    self.new_val = Value() 
    self.reflip_xrp = XRP()

    self.mhstats_details = False
    self.mhstats = {}
    self.made_proposals = 0
    self.accepted_proposals = 0

    self.hashval = rrandom.random.randbelow()
    return

  def mhstats_on(self):
    self.mhstats_details = True

  def mhstats_off(self):
    self.mhstats_details = False

  def mhstats_aggregated(self):
    d = {}
    d['made-proposals'] = self.made_proposals
    d['accepted-proposals'] = self.accepted_proposals
    return d

  def mhstats_detailed(self):
    return self.mhstats

  def get_log_score(self, id):
    if id == -1:
      return self.p
    else:
      return self.observes[id].p

  def weight(self):
    return self.db.weight() + self.weighted_db.weight()

  def random_choices(self):
    return self.db.__len__() + self.weighted_db.__len__()

  def assume(self, name, expr, id = -1):
    evalnode = ReducedEvalNode(self, self.env, expr)
    self.env.add_assume(name, evalnode)
    evalnode.add_assume(name, self.env)

    if id != -1:
      self.assumes[id] = evalnode
      assert id == len(self.directives)
      self.directives.append('assume')
    val = evalnode.evaluate()
    return val

  def predict(self, expr, id):
    evalnode = ReducedEvalNode(self, self.env, expr)

    assert id == len(self.directives)
    self.directives.append('predict')
    self.predicts[id] = evalnode

    evalnode.predict = True
    val = evalnode.evaluate()
    return val

  def observe(self, expr, obs_val, id):
    evalnode = ReducedEvalNode(self, self.env, expr)

    assert id == len(self.directives)
    self.directives.append('observe')
    self.observes[id] = evalnode

    evalnode.observed = True
    evalnode.observe_val = obs_val

    val = evalnode.evaluate()
    return val

  def forget(self, id):
    if id in self.observes:
      d = self.observes
    elif id in self.predicts:
      d = self.predicts
    else:
      raise RException("Can only forget predicts and observes")
    evalnode = d[id]
    evalnode.unevaluate()
    #del d[id]
    return

  def report_value(self, id):
    node = self.get_directive_node(id)
    if not node.active:
      raise RException("Error.  Perhaps this directive was forgotten?")
    val = node.val
    return val

  def get_directive_node(self, id):
    if self.directives[id] == 'assume':
      node = self.assumes[id]
    elif self.directives[id] == 'observe':
      node = self.observes[id]
    else:
      assert self.directives[id] == 'predict'
      node = self.predicts[id]
    return node

  def add_accepted_proposal(self, hashval):
    if self.mhstats_details:
      if hashval in self.mhstats:
        self.mhstats[hashval]['accepted-proposals'] += 1
      else:
        self.mhstats[hashval] = {}
        self.mhstats[hashval]['accepted-proposals'] = 1
        self.mhstats[hashval]['made-proposals'] = 0
    self.accepted_proposals += 1

  def add_made_proposal(self, hashval):
    if self.mhstats_details:
      if hashval in self.mhstats:
        self.mhstats[hashval]['made-proposals'] += 1
      else:
        self.mhstats[hashval] = {}
        self.mhstats[hashval]['made-proposals'] = 1
        self.mhstats[hashval]['accepted-proposals'] = 0
    self.made_proposals += 1

  def reflip(self, hashval):
    if self.debug:
      print self

    if hashval in self.choices:
      self.application_reflip = True
      self.reflip_node = self.choices[hashval]
      if not self.reflip_node.random_xrp_apply:
        raise RException("Reflipping something which isn't a random xrp application")
      if self.reflip_node.val is None:
        raise RException("Reflipping something which previously had value None")
    else:
      self.application_reflip = False # internal reflip
      (self.reflip_xrp, nodes) = self.xrps[hashval]
      self.nodes = list_nodes(nodes)

    self.eval_p = 0
    self.uneval_p = 0

    old_p = self.p
    self.old_to_new_q = - math.log(self.weight())

    if self.application_reflip:
      self.old_val = self.reflip_node.val
      self.new_val = self.reflip_node.reflip()
    else:
      # TODO: this is copied from traces.  is it correct?
      args_list = []
      self.old_vals = []
      for node in self.nodes:
        args_list.append(node.args)
        self.old_vals.append(node.val)
      self.old_to_new_q += math.log(self.reflip_xrp.state_weight())
      old_p += self.reflip_xrp.theta_prob()
  
      self.new_vals, q_forwards, q_back = self.reflip_xrp.theta_mh_prop(args_list, self.old_vals)
      self.old_to_new_q += q_forwards
      self.new_to_old_q += q_back
  
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        val = self.new_vals[i]
        node.set_val(val)
        node.propagate_up(False)

    new_p = self.p
    self.new_to_old_q = - math.log(self.weight())
    self.old_to_new_q += self.eval_p
    self.new_to_old_q += self.uneval_p

    if not self.application_reflip:
      new_p += self.reflip_xrp.theta_prob()
      self.new_to_old_q += math.log(self.reflip_xrp.state_weight())

    if self.debug:
      if self.application_reflip:
        print "\nCHANGING VAL OF ", self.reflip_node, "\n  FROM  :  ", self.old_val, "\n  TO   :  ", self.new_val, "\n"
        if (self.old_val.__eq__(self.new_val)).bool:
          print "SAME VAL"
      else:
        print "TRANSITIONING STATE OF ", self.reflip_xrp
  
      print "new db", self
      print "\nq(old -> new) : ", math.exp(self.old_to_new_q)
      print "q(new -> old) : ", math.exp(self.new_to_old_q )
      print "p(old) : ", math.exp(old_p)
      print "p(new) : ", math.exp(new_p)
      print 'transition prob : ',  math.exp(new_p + self.new_to_old_q - old_p - self.old_to_new_q) , "\n"
      print "\n-----------------------------------------\n"

    return old_p, self.old_to_new_q, new_p, self.new_to_old_q

  def restore(self):
    if self.application_reflip:
      self.reflip_node.restore(self.old_val)
    else:
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        node.set_val(self.old_vals[i])
        node.propagate_up(True)
      self.reflip_xrp.theta_mh_restore()
    if self.debug:
      print 'restore'

  def keep(self):
    if self.application_reflip:
      self.reflip_node.restore(self.new_val) # NOTE: Is necessary for correctness, as we must forget old branch
    else:
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        node.restore(self.new_vals[i])
      self.reflip_xrp.theta_mh_keep()

  def add_for_transition(self, xrp, evalnode):
    hashval = xrp.__hash__()
    if hashval not in self.xrps:
      self.xrps[hashval] = (xrp, {})
    (xrp, evalnodes) = self.xrps[hashval]
    evalnodes[evalnode] = True

    weight = xrp.state_weight()
    self.delete_from_db(xrp.__hash__())
    self.add_to_db(xrp.__hash__(), weight)

  # Add an XRP application node to the db
  def add_xrp(self, xrp, args, evalnode):
    weight = xrp.weight(args)
    evalnode.setargs(args)
    try:
      self.new_to_old_q += math.log(weight)
    except:
      pass # This is only necessary if we're reflipping

    if self.weighted_db.__contains__(evalnode.hashval) or self.db.__contains__(evalnode.hashval) or (evalnode.hashval in self.choices):
      raise RException("DB already had this evalnode")
    self.choices[evalnode.hashval] = evalnode
    self.add_to_db(evalnode.hashval, weight)

  def add_to_db(self, hashval, weight):
    if weight == 0:
      return
    elif weight == 1:
      self.db.__setitem__(hashval, True)
    else:
      self.weighted_db.__setitem__(hashval, True, weight)


  def remove_for_transition(self, xrp, evalnode):
    hashval = xrp.__hash__()
    (xrp, evalnodes) = self.xrps[hashval]
    del evalnodes[evalnode]
    if len(evalnodes) == 0:
      del self.xrps[hashval]
      self.delete_from_db(xrp.__hash__())

  def remove_xrp(self, evalnode):
    xrp = evalnode.xrp
    try:
      self.old_to_new_q += math.log(xrp.weight(evalnode.args))
    except:
      pass # This fails when restoring/keeping, for example

    if evalnode.hashval not in self.choices:
      raise RException("Choices did not already have this evalnode")
    else:
      del self.choices[evalnode.hashval]
      self.delete_from_db(evalnode.hashval)

  def delete_from_db(self, hashval):
    if self.db.__contains__(hashval):
      self.db.__delitem__(hashval)
    elif self.weighted_db.__contains__(hashval):
      self.weighted_db.__delitem__(hashval)

  def randomKey(self):
    if rrandom.random.random() * self.weight() > self.db.weight():
      return self.weighted_db.randomKey()
    else:
      return self.db.randomKey()

  def infer(self):
    try:
      hashval = self.randomKey()
    except:
      raise RException("Program has no randomness!")

    old_p, old_to_new_q, new_p, new_to_old_q = self.reflip(hashval)
    p = rrandom.random.random()
    if new_p + new_to_old_q - old_p - old_to_new_q < math.log(p):
      self.restore()
    else:
      self.keep()
      self.add_accepted_proposal(hashval)
    self.add_made_proposal(hashval)

  def reset(self):
    self.__init__()
    
  def __str__(self):
    string = "EvalNodeTree:"
    for evalnode in self.assumes.values():
      string += evalnode.str_helper()
    for evalnode in self.observes.values():
      string += evalnode.str_helper()
    for evalnode in self.predicts.values():
      string += evalnode.str_helper()
    return string
Example #6
0
class Traces(Engine):
    def __init__(self):
        self.engine_type = "traces"

        self.assumes = {}  # id -> evalnode
        self.observes = {}  # id -> evalnode
        self.predicts = {}  # id -> evalnode
        self.directives = []

        self.db = RandomChoiceDict()
        self.weighted_db = WeightedRandomChoiceDict()
        self.choices = {}  # hash -> evalnode
        self.xrps = {}  # hash -> (xrp, set of application nodes)

        self.env = EnvironmentNode()

        self.p = 0
        self.old_to_new_q = 0
        self.new_to_old_q = 0

        self.eval_xrps = []  # (xrp, args, val)
        self.uneval_xrps = []  # (xrp, args, val)

        self.debug = False

        # Stuff for restoring
        self.application_reflip = False
        self.reflip_node = EvalNode(self, self.env, VarExpression(""))
        self.nodes = []
        self.old_vals = []
        self.new_vals = []
        self.old_val = Value()
        self.new_val = Value()
        self.reflip_xrp = XRP()

        self.made_proposals = 0
        self.accepted_proposals = 0

        self.mhstats_details = False
        self.mhstats = {}
        return

    def mhstats_on(self):
        self.mhstats_details = True

    def mhstats_off(self):
        self.mhstats_details = False

    def mhstats_aggregated(self):
        d = {}
        d["made-proposals"] = self.made_proposals
        d["accepted-proposals"] = self.accepted_proposals
        return d

    def mhstats_detailed(self):
        return self.mhstats

    def get_log_score(self, id):
        if id == -1:
            return self.p
        else:
            return self.observes[id].p

    def weight(self):
        return self.db.weight() + self.weighted_db.weight()

    def random_choices(self):
        return self.db.__len__() + self.weighted_db.__len__()

    def assume(self, name, expr, id):
        evalnode = EvalNode(self, self.env, expr)

        if id != -1:
            self.assumes[id] = evalnode
            if id != len(self.directives):
                raise RException("Id %d does not agree with directives length of %d" % (id, len(self.directives)))
            self.directives.append("assume")

        val = evalnode.evaluate()
        self.env.add_assume(name, evalnode)
        evalnode.add_assume(name, self.env)
        self.env.set(name, val)
        return val

    def predict(self, expr, id):
        evalnode = EvalNode(self, self.env, expr)

        if id != len(self.directives):
            raise RException("Id %d does not agree with directives length of %d" % (id, len(self.directives)))
        self.directives.append("predict")
        self.predicts[id] = evalnode
        val = evalnode.evaluate(False)
        return val

    def observe(self, expr, obs_val, id):
        evalnode = EvalNode(self, self.env, expr)

        if id != len(self.directives):
            raise RException("Id %d does not agree with directives length of %d" % (id, len(self.directives)))
        self.directives.append("observe")
        self.observes[id] = evalnode

        evalnode.observed = True
        evalnode.observe_val = obs_val

        val = evalnode.evaluate()
        return val

    def forget(self, id):
        if id in self.observes:
            d = self.observes
        elif id in self.predicts:
            d = self.predicts
        else:
            raise RException("Can only forget predicts and observes")
        evalnode = d[id]
        evalnode.unevaluate()
        # del d[id]
        return

    def report_value(self, id):
        node = self.get_directive_node(id)
        if not node.active:
            raise RException("Error.  Perhaps this directive was forgotten?")
        val = node.val
        return val

    def get_directive_node(self, id):
        if self.directives[id] == "assume":
            node = self.assumes[id]
        elif self.directives[id] == "observe":
            node = self.observes[id]
        elif self.directives[id] == "predict":
            node = self.predicts[id]
        else:
            raise RException("Invalid directive")
        return node

    def add_accepted_proposal(self, hashval):
        if self.mhstats_details:
            if hashval in self.mhstats:
                self.mhstats[hashval]["accepted-proposals"] += 1
            else:
                self.mhstats[hashval] = {}
                self.mhstats[hashval]["accepted-proposals"] = 1
                self.mhstats[hashval]["made-proposals"] = 0
        self.accepted_proposals += 1

    def add_made_proposal(self, hashval):
        if self.mhstats_details:
            if hashval in self.mhstats:
                self.mhstats[hashval]["made-proposals"] += 1
            else:
                self.mhstats[hashval] = {}
                self.mhstats[hashval]["made-proposals"] = 1
                self.mhstats[hashval]["accepted-proposals"] = 0
        self.made_proposals += 1

    def reflip(self, hashval):
        if self.debug:
            print self

        if hashval in self.choices:
            self.application_reflip = True
            self.reflip_node = self.choices[hashval]
            if not self.reflip_node.random_xrp_apply:
                raise RException("Reflipping something which isn't a random xrp application")
            if self.reflip_node.val is None:
                raise RException("Reflipping something which previously had value None")
        else:
            self.application_reflip = False  # internal reflip
            (self.reflip_xrp, nodes) = self.xrps[hashval]
            self.nodes = list_nodes(nodes)

        self.old_to_new_q = 0
        self.new_to_old_q = 0

        old_p = self.p

        self.old_to_new_q = -math.log(self.weight())

        if self.application_reflip:
            self.old_val = self.reflip_node.val
            self.new_val = self.reflip_node.reflip()
        else:
            args_list = []
            self.old_vals = []
            for node in self.nodes:
                if not node.xrp_apply:
                    raise RException("non-XRP application node being used in transition")
                args_list.append(node.args)
                self.old_vals.append(node.val)
            self.old_to_new_q += math.log(self.reflip_xrp.state_weight())
            old_p += self.reflip_xrp.theta_prob()

            self.new_vals, q_forwards, q_back = self.reflip_xrp.theta_mh_prop(args_list, self.old_vals)
            self.old_to_new_q += q_forwards
            self.new_to_old_q += q_back

            for i in range(len(self.nodes)):
                node = self.nodes[i]
                val = self.new_vals[i]
                node.set_val(val)
                node.propagate_up(False, True)

        new_p = self.p
        self.new_to_old_q -= math.log(self.weight())

        if not self.application_reflip:
            new_p += self.reflip_xrp.theta_prob()
            self.new_to_old_q += math.log(self.reflip_xrp.state_weight())

        if self.debug:
            if self.application_reflip:
                print "\nCHANGING VAL OF ", self.reflip_node, "\n  FROM  :  ", self.old_val, "\n  TO   :  ", self.new_val, "\n"
                if (self.old_val.__eq__(self.new_val)).bool:
                    print "SAME VAL"
            else:
                print "TRANSITIONING STATE OF ", self.reflip_xrp

            print "new db", self
            print "\nq(old -> new) : ", math.exp(self.old_to_new_q)
            print "q(new -> old) : ", math.exp(self.new_to_old_q)
            print "p(old) : ", math.exp(old_p)
            print "p(new) : ", math.exp(new_p)
            print "transition prob : ", math.exp(new_p + self.new_to_old_q - old_p - self.old_to_new_q), "\n"

        return old_p, self.old_to_new_q, new_p, self.new_to_old_q

    def restore(self):
        if self.application_reflip:
            self.reflip_node.restore(self.old_val)
        else:
            for i in range(len(self.nodes)):
                node = self.nodes[i]
                node.set_val(self.old_vals[i])
                node.propagate_up(True, True)
            self.reflip_xrp.theta_mh_restore()
        if self.debug:
            print "restore"
            print "\n-----------------------------------------\n"

    def keep(self):
        if self.application_reflip:
            self.reflip_node.restore(self.new_val)  # NOTE: Is necessary for correctness, as we must forget old branch
        else:
            for i in range(len(self.nodes)):
                node = self.nodes[i]
                node.restore(self.new_vals[i])
            self.reflip_xrp.theta_mh_keep()
        if self.debug:
            print "keep"
            print "\n-----------------------------------------\n"

    def add_for_transition(self, xrp, evalnode):
        hashval = xrp.__hash__()
        if hashval not in self.xrps:
            self.xrps[hashval] = (xrp, {})
        (xrp, evalnodes) = self.xrps[hashval]
        evalnodes[evalnode] = True

        weight = xrp.state_weight()
        self.delete_from_db(hashval)
        self.add_to_db(hashval, weight)

    # Add an XRP application node to the db
    def add_xrp(self, xrp, args, val, evalnode, forcing=False):
        self.eval_xrps.append((xrp, args, val))

        evalnode.forcing = forcing
        if not xrp.resample:
            evalnode.random_xrp_apply = True
        prob = xrp.prob(val, args)
        evalnode.setargs(xrp, args, prob)
        xrp.incorporate(val, args)
        xrp.make_link(evalnode, args)
        if not forcing:
            self.old_to_new_q += prob
        self.p += prob
        self.add_for_transition(xrp, evalnode)
        if xrp.resample:
            return

        weight = xrp.weight(args)
        try:
            self.new_to_old_q += math.log(weight)
        except:
            pass  # This is only necessary if we're reflipping

        if (
            self.weighted_db.__contains__(evalnode.hashval)
            or self.db.__contains__(evalnode.hashval)
            or (evalnode.hashval in self.choices)
        ):
            raise RException("DB already had this evalnode")
        self.choices[evalnode.hashval] = evalnode
        self.add_to_db(evalnode.hashval, weight)

    def add_to_db(self, hashval, weight):
        if weight == 0:
            return
        elif weight == 1:
            self.db.__setitem__(hashval, True)
        else:
            self.weighted_db.__setitem__(hashval, True, weight)

    def remove_for_transition(self, xrp, evalnode):
        hashval = xrp.__hash__()
        (xrp, evalnodes) = self.xrps[hashval]
        del evalnodes[evalnode]
        if len(evalnodes) == 0:
            del self.xrps[hashval]
            self.delete_from_db(hashval)

    def remove_xrp(self, xrp, args, val, evalnode):
        self.uneval_xrps.append((xrp, args, val))

        xrp.break_link(evalnode, args)
        xrp.remove(val, args)
        prob = xrp.prob(val, args)
        if not evalnode.forcing:
            self.new_to_old_q += prob
        self.p -= prob
        self.remove_for_transition(xrp, evalnode)
        if xrp.resample:
            return

        xrp = evalnode.xrp
        try:  # TODO: dont do this here.. dont do cancellign
            self.old_to_new_q += math.log(xrp.weight(evalnode.args))
        except:
            pass  # This fails when restoring/keeping, for example

        if evalnode.hashval not in self.choices:
            raise RException("Choices did not already have this evalnode")
        else:
            del self.choices[evalnode.hashval]
            self.delete_from_db(evalnode.hashval)

    def delete_from_db(self, hashval):
        if self.db.__contains__(hashval):
            self.db.__delitem__(hashval)
        elif self.weighted_db.__contains__(hashval):
            self.weighted_db.__delitem__(hashval)

    def randomKey(self):
        if rrandom.random.random() * self.weight() > self.db.weight():
            return self.weighted_db.randomKey()
        else:
            return self.db.randomKey()

    def infer(self):
        try:
            hashval = self.randomKey()
        except:
            raise RException("Program has no randomness!")
        old_p, old_to_new_q, new_p, new_to_old_q = self.reflip(hashval)
        p = rrandom.random.random()
        if new_p + self.new_to_old_q - old_p - self.old_to_new_q < math.log(p):
            self.restore()
        else:
            self.keep()
            self.add_accepted_proposal(hashval)
        self.add_made_proposal(hashval)

    def reset(self):
        self.__init__()

    def __str__(self):
        string = "EvalNodeTree:"
        for evalnode in self.assumes.values():
            string += evalnode.str_helper()
        for evalnode in self.observes.values():
            string += evalnode.str_helper()
        for evalnode in self.predicts.values():
            string += evalnode.str_helper()
        return string