Ejemplo n.º 1
0
class ReducedTraces(Engine):
    def __init__(self):
        self.engine_type = 'reduced traces'
        self.assumes = {}  # id -> evalnode
        self.observes = {}  # id -> evalnode
        self.predicts = {}  # id -> evalnode
        self.directives = []

        self.db = RandomChoiceDict()
        self.weighted_db = WeightedRandomChoiceDict()
        self.choices = {}  # hash -> evalnode
        self.xrps = {}  # hash -> (xrp, set of application nodes)

        self.env = EnvironmentNode()

        self.p = 0
        self.uneval_p = 0
        self.eval_p = 0
        self.new_to_old_q = 0
        self.old_to_new_q = 0

        self.debug = False

        # necessary because of the new XRP interface requiring some state kept while doing inference
        self.application_reflip = False
        self.reflip_node = ReducedEvalNode(self, self.env, VarExpression(''))
        self.nodes = []
        self.old_vals = [Value()]
        self.new_vals = [Value()]
        self.old_val = Value()
        self.new_val = Value()
        self.reflip_xrp = XRP()

        self.mhstats_details = False
        self.mhstats = {}
        self.made_proposals = 0
        self.accepted_proposals = 0

        self.hashval = rrandom.random.randbelow()
        return

    def mhstats_on(self):
        self.mhstats_details = True

    def mhstats_off(self):
        self.mhstats_details = False

    def mhstats_aggregated(self):
        d = {}
        d['made-proposals'] = self.made_proposals
        d['accepted-proposals'] = self.accepted_proposals
        return d

    def mhstats_detailed(self):
        return self.mhstats

    def get_log_score(self, id):
        if id == -1:
            return self.p
        else:
            return self.observes[id].p

    def weight(self):
        return self.db.weight() + self.weighted_db.weight()

    def random_choices(self):
        return self.db.__len__() + self.weighted_db.__len__()

    def assume(self, name, expr, id=-1):
        evalnode = ReducedEvalNode(self, self.env, expr)
        self.env.add_assume(name, evalnode)
        evalnode.add_assume(name, self.env)

        if id != -1:
            self.assumes[id] = evalnode
            assert id == len(self.directives)
            self.directives.append('assume')
        val = evalnode.evaluate()
        return val

    def predict(self, expr, id):
        evalnode = ReducedEvalNode(self, self.env, expr)

        assert id == len(self.directives)
        self.directives.append('predict')
        self.predicts[id] = evalnode

        evalnode.predict = True
        val = evalnode.evaluate()
        return val

    def observe(self, expr, obs_val, id):
        evalnode = ReducedEvalNode(self, self.env, expr)

        assert id == len(self.directives)
        self.directives.append('observe')
        self.observes[id] = evalnode

        evalnode.observed = True
        evalnode.observe_val = obs_val

        val = evalnode.evaluate()
        return val

    def forget(self, id):
        if id in self.observes:
            d = self.observes
        elif id in self.predicts:
            d = self.predicts
        else:
            raise RException("Can only forget predicts and observes")
        evalnode = d[id]
        evalnode.unevaluate()
        #del d[id]
        return

    def report_value(self, id):
        node = self.get_directive_node(id)
        if not node.active:
            raise RException("Error.  Perhaps this directive was forgotten?")
        val = node.val
        return val

    def get_directive_node(self, id):
        if self.directives[id] == 'assume':
            node = self.assumes[id]
        elif self.directives[id] == 'observe':
            node = self.observes[id]
        else:
            assert self.directives[id] == 'predict'
            node = self.predicts[id]
        return node

    def add_accepted_proposal(self, hashval):
        if self.mhstats_details:
            if hashval in self.mhstats:
                self.mhstats[hashval]['accepted-proposals'] += 1
            else:
                self.mhstats[hashval] = {}
                self.mhstats[hashval]['accepted-proposals'] = 1
                self.mhstats[hashval]['made-proposals'] = 0
        self.accepted_proposals += 1

    def add_made_proposal(self, hashval):
        if self.mhstats_details:
            if hashval in self.mhstats:
                self.mhstats[hashval]['made-proposals'] += 1
            else:
                self.mhstats[hashval] = {}
                self.mhstats[hashval]['made-proposals'] = 1
                self.mhstats[hashval]['accepted-proposals'] = 0
        self.made_proposals += 1

    def reflip(self, hashval):
        if self.debug:
            print self

        if hashval in self.choices:
            self.application_reflip = True
            self.reflip_node = self.choices[hashval]
            if not self.reflip_node.random_xrp_apply:
                raise RException(
                    "Reflipping something which isn't a random xrp application"
                )
            if self.reflip_node.val is None:
                raise RException(
                    "Reflipping something which previously had value None")
        else:
            self.application_reflip = False  # internal reflip
            (self.reflip_xrp, nodes) = self.xrps[hashval]
            self.nodes = list_nodes(nodes)

        self.eval_p = 0
        self.uneval_p = 0

        old_p = self.p
        self.old_to_new_q = -math.log(self.weight())

        if self.application_reflip:
            self.old_val = self.reflip_node.val
            self.new_val = self.reflip_node.reflip()
        else:
            # TODO: this is copied from traces.  is it correct?
            args_list = []
            self.old_vals = []
            for node in self.nodes:
                args_list.append(node.args)
                self.old_vals.append(node.val)
            self.old_to_new_q += math.log(self.reflip_xrp.state_weight())
            old_p += self.reflip_xrp.theta_prob()

            self.new_vals, q_forwards, q_back = self.reflip_xrp.theta_mh_prop(
                args_list, self.old_vals)
            self.old_to_new_q += q_forwards
            self.new_to_old_q += q_back

            for i in range(len(self.nodes)):
                node = self.nodes[i]
                val = self.new_vals[i]
                node.set_val(val)
                node.propagate_up(False)

        new_p = self.p
        self.new_to_old_q = -math.log(self.weight())
        self.old_to_new_q += self.eval_p
        self.new_to_old_q += self.uneval_p

        if not self.application_reflip:
            new_p += self.reflip_xrp.theta_prob()
            self.new_to_old_q += math.log(self.reflip_xrp.state_weight())

        if self.debug:
            if self.application_reflip:
                print "\nCHANGING VAL OF ", self.reflip_node, "\n  FROM  :  ", self.old_val, "\n  TO   :  ", self.new_val, "\n"
                if (self.old_val.__eq__(self.new_val)).bool:
                    print "SAME VAL"
            else:
                print "TRANSITIONING STATE OF ", self.reflip_xrp

            print "new db", self
            print "\nq(old -> new) : ", math.exp(self.old_to_new_q)
            print "q(new -> old) : ", math.exp(self.new_to_old_q)
            print "p(old) : ", math.exp(old_p)
            print "p(new) : ", math.exp(new_p)
            print 'transition prob : ', math.exp(new_p + self.new_to_old_q -
                                                 old_p -
                                                 self.old_to_new_q), "\n"
            print "\n-----------------------------------------\n"

        return old_p, self.old_to_new_q, new_p, self.new_to_old_q

    def restore(self):
        if self.application_reflip:
            self.reflip_node.restore(self.old_val)
        else:
            for i in range(len(self.nodes)):
                node = self.nodes[i]
                node.set_val(self.old_vals[i])
                node.propagate_up(True)
            self.reflip_xrp.theta_mh_restore()
        if self.debug:
            print 'restore'

    def keep(self):
        if self.application_reflip:
            self.reflip_node.restore(
                self.new_val
            )  # NOTE: Is necessary for correctness, as we must forget old branch
        else:
            for i in range(len(self.nodes)):
                node = self.nodes[i]
                node.restore(self.new_vals[i])
            self.reflip_xrp.theta_mh_keep()

    def add_for_transition(self, xrp, evalnode):
        hashval = xrp.__hash__()
        if hashval not in self.xrps:
            self.xrps[hashval] = (xrp, {})
        (xrp, evalnodes) = self.xrps[hashval]
        evalnodes[evalnode] = True

        weight = xrp.state_weight()
        self.delete_from_db(xrp.__hash__())
        self.add_to_db(xrp.__hash__(), weight)

    # Add an XRP application node to the db
    def add_xrp(self, xrp, args, evalnode):
        weight = xrp.weight(args)
        evalnode.setargs(args)
        try:
            self.new_to_old_q += math.log(weight)
        except:
            pass  # This is only necessary if we're reflipping

        if self.weighted_db.__contains__(
                evalnode.hashval) or self.db.__contains__(
                    evalnode.hashval) or (evalnode.hashval in self.choices):
            raise RException("DB already had this evalnode")
        self.choices[evalnode.hashval] = evalnode
        self.add_to_db(evalnode.hashval, weight)

    def add_to_db(self, hashval, weight):
        if weight == 0:
            return
        elif weight == 1:
            self.db.__setitem__(hashval, True)
        else:
            self.weighted_db.__setitem__(hashval, True, weight)

    def remove_for_transition(self, xrp, evalnode):
        hashval = xrp.__hash__()
        (xrp, evalnodes) = self.xrps[hashval]
        del evalnodes[evalnode]
        if len(evalnodes) == 0:
            del self.xrps[hashval]
            self.delete_from_db(xrp.__hash__())

    def remove_xrp(self, evalnode):
        xrp = evalnode.xrp
        try:
            self.old_to_new_q += math.log(xrp.weight(evalnode.args))
        except:
            pass  # This fails when restoring/keeping, for example

        if evalnode.hashval not in self.choices:
            raise RException("Choices did not already have this evalnode")
        else:
            del self.choices[evalnode.hashval]
            self.delete_from_db(evalnode.hashval)

    def delete_from_db(self, hashval):
        if self.db.__contains__(hashval):
            self.db.__delitem__(hashval)
        elif self.weighted_db.__contains__(hashval):
            self.weighted_db.__delitem__(hashval)

    def randomKey(self):
        if rrandom.random.random() * self.weight() > self.db.weight():
            return self.weighted_db.randomKey()
        else:
            return self.db.randomKey()

    def infer(self):
        try:
            hashval = self.randomKey()
        except:
            raise RException("Program has no randomness!")

        old_p, old_to_new_q, new_p, new_to_old_q = self.reflip(hashval)
        p = rrandom.random.random()
        if new_p + new_to_old_q - old_p - old_to_new_q < math.log(p):
            self.restore()
        else:
            self.keep()
            self.add_accepted_proposal(hashval)
        self.add_made_proposal(hashval)

    def reset(self):
        self.__init__()

    def __str__(self):
        string = "EvalNodeTree:"
        for evalnode in self.assumes.values():
            string += evalnode.str_helper()
        for evalnode in self.observes.values():
            string += evalnode.str_helper()
        for evalnode in self.predicts.values():
            string += evalnode.str_helper()
        return string
class ReducedTraces(Engine):
  def __init__(self):
    self.engine_type = 'reduced traces'
    self.assumes = {} # id -> evalnode
    self.observes = {} # id -> evalnode
    self.predicts = {} # id -> evalnode
    self.directives = []

    self.db = RandomChoiceDict() 
    self.weighted_db = WeightedRandomChoiceDict() 
    self.choices = {} # hash -> evalnode
    self.xrps = {} # hash -> (xrp, set of application nodes)

    self.env = EnvironmentNode()

    self.p = 0
    self.uneval_p = 0
    self.eval_p = 0
    self.new_to_old_q = 0
    self.old_to_new_q = 0

    self.debug = False

    # necessary because of the new XRP interface requiring some state kept while doing inference
    self.application_reflip = False
    self.reflip_node = ReducedEvalNode(self, self.env, VarExpression(''))
    self.nodes = []
    self.old_vals = [Value()]
    self.new_vals = [Value()]
    self.old_val = Value() 
    self.new_val = Value() 
    self.reflip_xrp = XRP()

    self.mhstats_details = False
    self.mhstats = {}
    self.made_proposals = 0
    self.accepted_proposals = 0

    self.hashval = rrandom.random.randbelow()
    return

  def mhstats_on(self):
    self.mhstats_details = True

  def mhstats_off(self):
    self.mhstats_details = False

  def mhstats_aggregated(self):
    d = {}
    d['made-proposals'] = self.made_proposals
    d['accepted-proposals'] = self.accepted_proposals
    return d

  def mhstats_detailed(self):
    return self.mhstats

  def get_log_score(self, id):
    if id == -1:
      return self.p
    else:
      return self.observes[id].p

  def weight(self):
    return self.db.weight() + self.weighted_db.weight()

  def random_choices(self):
    return self.db.__len__() + self.weighted_db.__len__()

  def assume(self, name, expr, id = -1):
    evalnode = ReducedEvalNode(self, self.env, expr)
    self.env.add_assume(name, evalnode)
    evalnode.add_assume(name, self.env)

    if id != -1:
      self.assumes[id] = evalnode
      assert id == len(self.directives)
      self.directives.append('assume')
    val = evalnode.evaluate()
    return val

  def predict(self, expr, id):
    evalnode = ReducedEvalNode(self, self.env, expr)

    assert id == len(self.directives)
    self.directives.append('predict')
    self.predicts[id] = evalnode

    evalnode.predict = True
    val = evalnode.evaluate()
    return val

  def observe(self, expr, obs_val, id):
    evalnode = ReducedEvalNode(self, self.env, expr)

    assert id == len(self.directives)
    self.directives.append('observe')
    self.observes[id] = evalnode

    evalnode.observed = True
    evalnode.observe_val = obs_val

    val = evalnode.evaluate()
    return val

  def forget(self, id):
    if id in self.observes:
      d = self.observes
    elif id in self.predicts:
      d = self.predicts
    else:
      raise RException("Can only forget predicts and observes")
    evalnode = d[id]
    evalnode.unevaluate()
    #del d[id]
    return

  def report_value(self, id):
    node = self.get_directive_node(id)
    if not node.active:
      raise RException("Error.  Perhaps this directive was forgotten?")
    val = node.val
    return val

  def get_directive_node(self, id):
    if self.directives[id] == 'assume':
      node = self.assumes[id]
    elif self.directives[id] == 'observe':
      node = self.observes[id]
    else:
      assert self.directives[id] == 'predict'
      node = self.predicts[id]
    return node

  def add_accepted_proposal(self, hashval):
    if self.mhstats_details:
      if hashval in self.mhstats:
        self.mhstats[hashval]['accepted-proposals'] += 1
      else:
        self.mhstats[hashval] = {}
        self.mhstats[hashval]['accepted-proposals'] = 1
        self.mhstats[hashval]['made-proposals'] = 0
    self.accepted_proposals += 1

  def add_made_proposal(self, hashval):
    if self.mhstats_details:
      if hashval in self.mhstats:
        self.mhstats[hashval]['made-proposals'] += 1
      else:
        self.mhstats[hashval] = {}
        self.mhstats[hashval]['made-proposals'] = 1
        self.mhstats[hashval]['accepted-proposals'] = 0
    self.made_proposals += 1

  def reflip(self, hashval):
    if self.debug:
      print self

    if hashval in self.choices:
      self.application_reflip = True
      self.reflip_node = self.choices[hashval]
      if not self.reflip_node.random_xrp_apply:
        raise RException("Reflipping something which isn't a random xrp application")
      if self.reflip_node.val is None:
        raise RException("Reflipping something which previously had value None")
    else:
      self.application_reflip = False # internal reflip
      (self.reflip_xrp, nodes) = self.xrps[hashval]
      self.nodes = list_nodes(nodes)

    self.eval_p = 0
    self.uneval_p = 0

    old_p = self.p
    self.old_to_new_q = - math.log(self.weight())

    if self.application_reflip:
      self.old_val = self.reflip_node.val
      self.new_val = self.reflip_node.reflip()
    else:
      # TODO: this is copied from traces.  is it correct?
      args_list = []
      self.old_vals = []
      for node in self.nodes:
        args_list.append(node.args)
        self.old_vals.append(node.val)
      self.old_to_new_q += math.log(self.reflip_xrp.state_weight())
      old_p += self.reflip_xrp.theta_prob()
  
      self.new_vals, q_forwards, q_back = self.reflip_xrp.theta_mh_prop(args_list, self.old_vals)
      self.old_to_new_q += q_forwards
      self.new_to_old_q += q_back
  
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        val = self.new_vals[i]
        node.set_val(val)
        node.propagate_up(False)

    new_p = self.p
    self.new_to_old_q = - math.log(self.weight())
    self.old_to_new_q += self.eval_p
    self.new_to_old_q += self.uneval_p

    if not self.application_reflip:
      new_p += self.reflip_xrp.theta_prob()
      self.new_to_old_q += math.log(self.reflip_xrp.state_weight())

    if self.debug:
      if self.application_reflip:
        print "\nCHANGING VAL OF ", self.reflip_node, "\n  FROM  :  ", self.old_val, "\n  TO   :  ", self.new_val, "\n"
        if (self.old_val.__eq__(self.new_val)).bool:
          print "SAME VAL"
      else:
        print "TRANSITIONING STATE OF ", self.reflip_xrp
  
      print "new db", self
      print "\nq(old -> new) : ", math.exp(self.old_to_new_q)
      print "q(new -> old) : ", math.exp(self.new_to_old_q )
      print "p(old) : ", math.exp(old_p)
      print "p(new) : ", math.exp(new_p)
      print 'transition prob : ',  math.exp(new_p + self.new_to_old_q - old_p - self.old_to_new_q) , "\n"
      print "\n-----------------------------------------\n"

    return old_p, self.old_to_new_q, new_p, self.new_to_old_q

  def restore(self):
    if self.application_reflip:
      self.reflip_node.restore(self.old_val)
    else:
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        node.set_val(self.old_vals[i])
        node.propagate_up(True)
      self.reflip_xrp.theta_mh_restore()
    if self.debug:
      print 'restore'

  def keep(self):
    if self.application_reflip:
      self.reflip_node.restore(self.new_val) # NOTE: Is necessary for correctness, as we must forget old branch
    else:
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        node.restore(self.new_vals[i])
      self.reflip_xrp.theta_mh_keep()

  def add_for_transition(self, xrp, evalnode):
    hashval = xrp.__hash__()
    if hashval not in self.xrps:
      self.xrps[hashval] = (xrp, {})
    (xrp, evalnodes) = self.xrps[hashval]
    evalnodes[evalnode] = True

    weight = xrp.state_weight()
    self.delete_from_db(xrp.__hash__())
    self.add_to_db(xrp.__hash__(), weight)

  # Add an XRP application node to the db
  def add_xrp(self, xrp, args, evalnode):
    weight = xrp.weight(args)
    evalnode.setargs(args)
    try:
      self.new_to_old_q += math.log(weight)
    except:
      pass # This is only necessary if we're reflipping

    if self.weighted_db.__contains__(evalnode.hashval) or self.db.__contains__(evalnode.hashval) or (evalnode.hashval in self.choices):
      raise RException("DB already had this evalnode")
    self.choices[evalnode.hashval] = evalnode
    self.add_to_db(evalnode.hashval, weight)

  def add_to_db(self, hashval, weight):
    if weight == 0:
      return
    elif weight == 1:
      self.db.__setitem__(hashval, True)
    else:
      self.weighted_db.__setitem__(hashval, True, weight)


  def remove_for_transition(self, xrp, evalnode):
    hashval = xrp.__hash__()
    (xrp, evalnodes) = self.xrps[hashval]
    del evalnodes[evalnode]
    if len(evalnodes) == 0:
      del self.xrps[hashval]
      self.delete_from_db(xrp.__hash__())

  def remove_xrp(self, evalnode):
    xrp = evalnode.xrp
    try:
      self.old_to_new_q += math.log(xrp.weight(evalnode.args))
    except:
      pass # This fails when restoring/keeping, for example

    if evalnode.hashval not in self.choices:
      raise RException("Choices did not already have this evalnode")
    else:
      del self.choices[evalnode.hashval]
      self.delete_from_db(evalnode.hashval)

  def delete_from_db(self, hashval):
    if self.db.__contains__(hashval):
      self.db.__delitem__(hashval)
    elif self.weighted_db.__contains__(hashval):
      self.weighted_db.__delitem__(hashval)

  def randomKey(self):
    if rrandom.random.random() * self.weight() > self.db.weight():
      return self.weighted_db.randomKey()
    else:
      return self.db.randomKey()

  def infer(self):
    try:
      hashval = self.randomKey()
    except:
      raise RException("Program has no randomness!")

    old_p, old_to_new_q, new_p, new_to_old_q = self.reflip(hashval)
    p = rrandom.random.random()
    if new_p + new_to_old_q - old_p - old_to_new_q < math.log(p):
      self.restore()
    else:
      self.keep()
      self.add_accepted_proposal(hashval)
    self.add_made_proposal(hashval)

  def reset(self):
    self.__init__()
    
  def __str__(self):
    string = "EvalNodeTree:"
    for evalnode in self.assumes.values():
      string += evalnode.str_helper()
    for evalnode in self.observes.values():
      string += evalnode.str_helper()
    for evalnode in self.predicts.values():
      string += evalnode.str_helper()
    return string
Ejemplo n.º 3
0
class RandomDB(Engine):
    def __init__(self):
        """Create an empty database bound to a fresh root environment.

        ``p``, ``eval_p`` and ``uneval_p`` all start at 0.
        """
        self.engine_type = 'randomdb'

        # tuple(stack) -> (xrp, value, args, is_obs_noise)
        self.db = RandomChoiceDict()
        self.db_noise = {}  # same layout, observation-noise entries only
        self.log = []  # memorized ('insert' | 'remove', ...) records

        # ALWAYS WORKING WITH LOG PROBABILITIES
        self.p = 0
        self.eval_p = 0
        self.uneval_p = 0

        self.env = Environment()

        self.assumes = {}  # id -> (varname, expr)
        self.observes = {}  # id -> (expr, obs_val)
        self.predicts = {}  # id -> expr
        self.vars = {}  # varname -> expr

    def reset(self):
        """Throw away every directive and entry by re-running the constructor."""
        self.__init__()

    def assume(self, varname, expr, id):
        """Bind ``varname`` to ``expr`` evaluated with reflip=True under stack [id].

        The resulting value is stored in the root environment and returned.
        """
        self.assumes[id] = (varname, expr)
        self.vars[varname] = expr
        result = self.evaluate(expr, self.env, reflip=True, stack=[id])
        self.env.set(varname, result)
        return result

    def observe(self, expr, obs_val, id):
        """Record ``expr`` as observed to equal ``obs_val``.

        ``expr`` is evaluated under stack [id] with ``xrp_force_val=obs_val``.
        When its operator is an XRP, evaluate() stores the forced draw as an
        observation-noise entry.

        Returns:
          ``expr.hashval``.

        Raises:
          RException: if ``expr.hashval`` is already a key of ``observes``.
        """
        # NOTE(review): ``observes`` is keyed by directive id, not by
        # expression hash, so this check compares unrelated values; confirm
        # which key was intended.
        if expr.hashval in self.observes:
            raise RException('Already observed %s' % str(expr))
        self.observes[id] = (expr, obs_val)
        # bit of a hack, here, to make it recognize same things as with noisy_expr
        self.evaluate(expr, self.env, reflip=False, stack=[id],
                      xrp_force_val=obs_val)
        return expr.hashval

    def predict(self, expr, id):
        """Remember ``expr`` under ``id`` and return it evaluated with reflip=True."""
        self.predicts[id] = expr
        return self.evaluate(expr, self.env, reflip=True, stack=[id])

    def forget(self, id):
        """Drop observation ``id`` and its forced observation-noise entry.

        observe() evaluates under stack [id], so the forced draw (if any) is
        stored under key (id,).  The previous code removed ['obs', id], a key
        that is never inserted, which made remove() fail its has() assertion.

        Raises:
          AssertionError: if ``id`` is not a recorded observation.
        """
        # Validate before mutating anything.
        assert id in self.observes
        stack = [id]
        # No entry exists when the observed operator was not an XRP.
        if self.has(stack):
            self.remove(stack)
        # NOTE(review): random choices made while evaluating the observed
        # expression's arguments are not removed here.
        del self.observes[id]

    def insert(self, stack, xrp, value, args, is_obs_noise=False,
               memorize=True):
        """Record that ``xrp`` produced ``value`` from ``args`` at ``stack``.

        Any existing entry at ``stack`` is removed first.  The score
        ``xrp.prob(value, args)`` is taken before ``xrp.incorporate`` runs.
        It is added to ``p`` and, for non-noise entries, also to ``eval_p``.
        The operation is appended to ``log`` when ``memorize`` is true.
        """
        key = tuple(stack)
        if self.has(key):
            self.remove(key)
        score = xrp.prob(value, args)
        self.p += score
        xrp.incorporate(value, args)
        if is_obs_noise:
            self.db_noise[key] = (xrp, value, args, True)
        else:
            self.db[key] = (xrp, value, args, False)
            self.eval_p += score  # hmmm..
        if memorize:
            self.log.append(('insert', key, xrp, value, args, is_obs_noise))

    def remove(self, stack, memorize=True):
        """Delete the entry at ``stack`` and subtract its score from ``p``.

        ``xrp.remove`` runs before the score is recomputed.  Non-noise
        removals also add that score to ``uneval_p``.  The operation is
        appended to ``log`` when ``memorize`` is true.
        """
        key = tuple(stack)
        assert self.has(key)
        xrp, value, args, is_obs_noise = self.get(key)
        xrp.remove(value, args)
        score = xrp.prob(value, args)
        self.p -= score
        if not is_obs_noise:
            del self.db[key]
            self.uneval_p += score  # previously unindented...
        else:
            del self.db_noise[key]
        if memorize:
            self.log.append(('remove', key, xrp, value, args, is_obs_noise))

    def has(self, stack):
        """Return True when ``stack`` has an entry in ``db`` or ``db_noise``."""
        key = tuple(stack)
        if key in self.db:
            return True
        return key in self.db_noise

    def get(self, stack):
        """Return the (xrp, value, args, is_obs_noise) entry at ``stack``.

        ``db`` is checked before ``db_noise``.

        Raises:
          RException: if neither table has an entry for ``stack``.
        """
        key = tuple(stack)
        for table in (self.db, self.db_noise):
            if key in table:
                return table[key]
        raise RException('Failed to get stack %s' % str(key))

    def infer(self):
        """Reflip one randomly drawn stack; return silently if none can be drawn."""
        try:
            chosen = self.random_stack()
        except:
            # Any failure to draw (presumably an empty db) skips this step.
            return
        self.reflip(chosen)

    def random_stack(self):
        """Return a stack key drawn by ``self.db.randomKey()`` (noise entries excluded)."""
        return self.db.randomKey()

    def evaluate_recurse(self, subexpr, env, reflip, stack, addition):
        """Evaluate ``subexpr`` under ``stack`` extended by ``addition``.

        ``stack`` is copied, so the caller's list is not modified.
        """
        # list.extend() mutates in place and returns None; the former
        # one-liner `[...].extend(addition)` therefore passed stack=None
        # to evaluate().
        newstack = [s for s in stack]
        newstack.extend(addition)
        return self.evaluate(subexpr, env, reflip, newstack)

    def binary_op_evaluate(self, expr, env, reflip, stack):
        """Evaluate the first two children of ``expr`` and return them as a pair.

        The stack additions passed to evaluate_recurse are ['operand0'] and
        ['operand1'], in that order.
        """
        left = self.evaluate_recurse(expr.children[0], env, reflip, stack,
                                     ['operand0'])
        right = self.evaluate_recurse(expr.children[1], env, reflip, stack,
                                      ['operand1'])
        return left, right

    def children_evaluate(self, expr, env, reflip, stack):
        """Evaluate every child of ``expr`` in order and return the list of values.

        Child ``i`` is passed to evaluate_recurse with stack addition
        ['child<i>'].
        """
        values = []
        for index in range(len(expr.children)):
            child = expr.children[index]
            values.append(
                self.evaluate_recurse(child, env, reflip, stack,
                                      ['child' + str(index)]))
        return values

    # Draws a sample value (without re-sampling other values) given its parents, and sets it
    def evaluate(self,
                 expr,
                 env=None,
                 reflip=False,
                 stack=[],
                 xrp_force_val=None):
        if env is None:
            env = self.env

        if xrp_force_val is not None:
            assert expr.type == 'apply'

        if expr.type == 'value':
            val = expr.val
        elif expr.type == 'variable':
            var = expr.name
            (val, lookup_env) = env.lookup(var)
        elif expr.type == 'if':
            cond = self.evaluate_recurse(expr.cond, env, reflip, stack,
                                         ['cond'])
            if cond.bool:
                self.unevaluate(stack + ['false'])
                val = self.evaluate_recurse(expr.true, env, reflip, stack,
                                            ['true'])
            else:
                self.unevaluate(stack + ['true'])
                val = self.evaluate_recurse(expr.false, env, reflip, stack,
                                            ['false'])
        #elif expr.type == 'switch':
        #  index = self.evaluate_recurse(expr.index, env, reflip, stack , ['index'])
        #  assert 0 <= index.num < expr.n
        #  for i in range(expr.n):
        #    if i != index.num:
        #      self.unevaluate(stack + ['child' + str(i)])
        #  val = self.evaluate_recurse(expr.children[index.num], env, reflip, stack, ['child' + str(index.num)])
        elif expr.type == 'let':
            # NOTE: this really is a let*

            n = len(expr.vars)
            assert len(expr.expressions) == n
            values = []
            new_env = env
            for i in range(n):  # Bind variables
                new_env = new_env.spawn_child()
                val = self.evaluate_recurse(expr.expressions[i], new_env,
                                            reflip, stack, ['let' + str(i)])
                values.append(val)
                new_env.set(expr.vars[i], values[i])
                if val.type == 'procedure':
                    val.env = new_env
            new_body = expr.body.replace(new_env)
            val = self.evaluate_recurse(new_body, new_env, reflip, stack,
                                        ['body'])
        elif expr.type == 'apply':
            n = len(expr.children)
            args = [
                self.evaluate_recurse(expr.children[i], env, reflip, stack,
                                      ['arg' + str(i)]) for i in range(n)
            ]
            op = self.evaluate_recurse(expr.op, env, reflip, stack,
                                       ['operator'])

            addition = ','.join([x.str_hash for x in args])
            if op.type == 'procedure':
                self.unevaluate(stack + ['apply'], addition)
                if n != len(op.vars):
                    raise RException(
                        'Procedure should have %d arguments.  \nVars were \n%s\n, but children were \n%s.'
                        % (n, op.vars, expr.children))
                new_env = op.env.spawn_child()
                for i in range(n):
                    new_env.set(op.vars[i], args[i])
                val = self.evaluate_recurse(op.body, new_env, reflip, stack,
                                            ['apply', addition])
            elif op.type == 'xrp':
                self.unevaluate(stack + ['apply'], addition)

                if xrp_force_val is not None:
                    assert not reflip
                    if self.has(stack):
                        self.remove(stack)
                    self.insert(stack, op.xrp, xrp_force_val, args, True)
                    val = xrp_force_val
                else:
                    substack = stack + ['apply', addition]
                    if not self.has(substack):
                        if op.xrp.is_mem_proc():
                            val = op.xrp.apply_mem(args, self, stack)
                        else:
                            val = op.xrp.apply(args)
                        self.insert(substack, op.xrp, val, args)
                    else:
                        if reflip:
                            self.remove(substack)
                            if op.xrp.is_mem_proc():
                                val = op.xrp.apply_mem(args, self, stack)
                            else:
                                val = op.xrp.apply(args)
                            self.insert(substack, op.xrp, val, args)
                        else:
                            (xrp, val, dbargs,
                             is_obs_noise) = self.get(substack)
                            assert not is_obs_noise
            else:
                raise RException(
                    'Must apply either a procedure or xrp.  Instead got expression %s'
                    % str(op))
        elif expr.type == 'function':
            n = len(expr.vars)
            new_env = env.spawn_child()
            bound = {}
            for i in range(n):  # Bind variables
                bound[expr.vars[i]] = True
            procedure_body = expr.body.replace(new_env, bound)
            val = Procedure(expr.vars, procedure_body, env)
        elif expr.type == '=':
            (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
            val = val1.__eq__(val2)
        elif expr.type == '<':
            (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
            val = val1.__lt__(val2)
        elif expr.type == '>':
            (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
            val = val1.__gt__(val2)
        elif expr.type == '<=':
            (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
            val = val1.__le__(val2)
        elif expr.type == '>=':
            (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
            val = val1.__ge__(val2)
        elif expr.type == '&':
            vals = self.children_evaluate(expr, env, reflip, stack)
            andval = BoolValue(True)
            for x in vals:
                andval = andval.__and__(x.bool)
            val = andval
        elif expr.type == '^':
            vals = self.children_evaluate(expr, env, reflip, stack)
            xorval = BoolValue(True)
            for x in vals:
                xorval = xorval.__xor__(x.bool)
            val = xorval
        elif expr.type == '|':
            vals = self.children_evaluate(expr, env, reflip, stack)
            orval = BoolValue(False)
            for x in vals:
                orval = orval.__or__(x.bool)
            val = orval
        elif expr.type == '~':
            negval = self.evaluate_recurse(expr.children[0], env, reflip,
                                           stack, ['neg'])
            val = negval.__inv__()
        elif expr.type == '+':
            vals = self.children_evaluate(expr, env, reflip, stack)
            sum_val = NatValue(0)
            for x in vals:
                sum_val = sum_val.__add__(x)
            val = sum_val
        elif expr.type == '-':
            val1 = self.evaluate_recurse(expr.children[0], env, reflip, stack,
                                         ['sub0'])
            val2 = self.evaluate_recurse(expr.children[1], env, reflip, stack,
                                         ['sub1'])
            val = val1.__sub__(val2)
        elif expr.type == '*':
            vals = self.children_evaluate(expr, env, reflip, stack)
            prod_val = NatValue(1)
            for x in vals:
                prod_val = prod_val.__mul__(x)
            val = prod_val
        elif expr.type == '/':
            val1 = self.evaluate_recurse(expr.children[0], env, reflip, stack,
                                         ['div0'])
            val2 = self.evaluate_recurse(expr.children[1], env, reflip, stack,
                                         ['div1'])
            val = val1.__div__(val2)
        else:
            raise RException('Invalid expression type %s' % expr.type)
        return val

    def unevaluate(self, uneval_stack, args=None):
        """Remove recorded random choices whose stack starts with ``uneval_stack``.

        Keys of both ``self.db`` and ``self.db_noise`` are scanned. A key
        matches when its leading elements equal ``uneval_stack`` (a list).

        :param uneval_stack: list prefix identifying the subtree to discard.
        :param args: optional. When given, every matching key must have an
            element right after the prefix. Keys whose element equals
            ``args`` are kept, and all other branches under the prefix are
            removed. ``evaluate`` passes the comma-joined ``addition`` string
            here, which is the same object it stores in the key, so it is
            compared as-is.
        """
        # Only list args need converting (a list could never appear inside a
        # hashable key). The previous unconditional ``tuple(args)`` turned
        # the ``addition`` string into a tuple of characters. That tuple
        # never equals the string stored in the key, so the branch that was
        # meant to be kept was removed as well.
        if isinstance(args, list):
            args = tuple(args)

        # Snapshot the keys first, because self.remove mutates the dbs.
        to_unevaluate = []
        for tuple_stack in self.db:
            to_unevaluate.append(tuple_stack)
        for tuple_stack in self.db_noise:
            to_unevaluate.append(tuple_stack)

        prefix_len = len(uneval_stack)
        to_delete = []
        for tuple_stack in to_unevaluate:
            stack = list(tuple_stack)
            if len(stack) < prefix_len or stack[:prefix_len] != uneval_stack:
                continue
            if args is None:
                to_delete.append(tuple_stack)
            else:
                assert len(stack) > prefix_len
                if stack[prefix_len] != args:
                    to_delete.append(tuple_stack)

        for tuple_stack in to_delete:
            self.remove(tuple_stack)

    def save(self):
        """Start a new proposal: empty the undo log and zero both accumulators.

        ``restore`` replays ``self.log`` to undo a rejected proposal.
        ``uneval_p`` and ``eval_p`` are read by ``reflip`` after the proposal
        has run.
        """
        self.uneval_p = self.eval_p = 0
        self.log = []

    def restore(self):
        """Undo the logged inserts/removes, most recent first.

        The log is reversed in place, and each entry is replayed as its
        inverse with ``memorize=False`` so the undo itself is not logged.
        """
        self.log.reverse()
        for entry in self.log:
            action, stack, xrp, value, args, is_obs_noise = entry
            if action != 'insert':
                assert action == 'remove'
                self.insert(stack, xrp, value, args, is_obs_noise, False)
                continue
            self.remove(stack, False)

    def reflip(self, stack):
        (xrp, val, args, is_obs_noise) = self.get(stack)

        #debug = True
        debug = False

        old_p = self.p
        old_to_new_q = -math.log(len(self.db))
        if debug:
            print "old_db", self

        self.save()

        self.remove(stack)
        if xrp.is_mem_proc():
            new_val = xrp.apply(args, list(stack))
        else:
            new_val = xrp.apply(args)
        self.insert(stack, xrp, new_val, args)

        if debug:
            print "\nCHANGING ", stack, "\n  TO   :  ", new_val, "\n"

        if val == new_val:
            return

        self.rerun(False)
        new_p = self.p
        new_to_old_q = self.uneval_p - math.log(len(self.db))
        old_to_new_q += self.eval_p
        if debug:
            print "new db", self, \
                  "\nq(old -> new) : ", old_to_new_q, \
                  "q(new -> old) : ", new_to_old_q, \
                  "p(old) : ", old_p, \
                  "p(new) : ", new_p, \
                  'log transition prob : ',  new_p + new_to_old_q - old_p - old_to_new_q , "\n"

        if old_p * old_to_new_q > 0:
            p = rrandom.random.random()
            if new_p + new_to_old_q - old_p - old_to_new_q < math.log(p):
                self.restore()
                if debug:
                    print 'restore'
                self.rerun(False)

        if debug:
            print "new db", self
            print "\n-----------------------------------------\n"

    def __str__(self):
        string = 'DB with state:'
        string += '\n  Regular Flips:'
        for s in self.db:
            string += '\n    %s <- %s' % (self.db[s][1].val, s)
        string += '\n  Observe Flips:'
        for s in self.db_noise:
            string += '\n    %s <- %s' % (self.db_noise[s][1].val, s)
        return string

    def __contains__(self, stack):
        return self.has(self, stack)

    def __getitem__(self, stack):
        return self.get(self, stack)
Ejemplo n.º 4
0
class Traces(Engine):
    def __init__(self):
        self.engine_type = "traces"

        self.assumes = {}  # id -> evalnode
        self.observes = {}  # id -> evalnode
        self.predicts = {}  # id -> evalnode
        self.directives = []

        self.db = RandomChoiceDict()
        self.weighted_db = WeightedRandomChoiceDict()
        self.choices = {}  # hash -> evalnode
        self.xrps = {}  # hash -> (xrp, set of application nodes)

        self.env = EnvironmentNode()

        self.p = 0
        self.old_to_new_q = 0
        self.new_to_old_q = 0

        self.eval_xrps = []  # (xrp, args, val)
        self.uneval_xrps = []  # (xrp, args, val)

        self.debug = False

        # Stuff for restoring
        self.application_reflip = False
        self.reflip_node = EvalNode(self, self.env, VarExpression(""))
        self.nodes = []
        self.old_vals = []
        self.new_vals = []
        self.old_val = Value()
        self.new_val = Value()
        self.reflip_xrp = XRP()

        self.made_proposals = 0
        self.accepted_proposals = 0

        self.mhstats_details = False
        self.mhstats = {}
        return

    def mhstats_on(self):
        self.mhstats_details = True

    def mhstats_off(self):
        self.mhstats_details = False

    def mhstats_aggregated(self):
        d = {}
        d["made-proposals"] = self.made_proposals
        d["accepted-proposals"] = self.accepted_proposals
        return d

    def mhstats_detailed(self):
        return self.mhstats

    def get_log_score(self, id):
        if id == -1:
            return self.p
        else:
            return self.observes[id].p

    def weight(self):
        return self.db.weight() + self.weighted_db.weight()

    def random_choices(self):
        return self.db.__len__() + self.weighted_db.__len__()

    def assume(self, name, expr, id):
        evalnode = EvalNode(self, self.env, expr)

        if id != -1:
            self.assumes[id] = evalnode
            if id != len(self.directives):
                raise RException("Id %d does not agree with directives length of %d" % (id, len(self.directives)))
            self.directives.append("assume")

        val = evalnode.evaluate()
        self.env.add_assume(name, evalnode)
        evalnode.add_assume(name, self.env)
        self.env.set(name, val)
        return val

    def predict(self, expr, id):
        evalnode = EvalNode(self, self.env, expr)

        if id != len(self.directives):
            raise RException("Id %d does not agree with directives length of %d" % (id, len(self.directives)))
        self.directives.append("predict")
        self.predicts[id] = evalnode
        val = evalnode.evaluate(False)
        return val

    def observe(self, expr, obs_val, id):
        evalnode = EvalNode(self, self.env, expr)

        if id != len(self.directives):
            raise RException("Id %d does not agree with directives length of %d" % (id, len(self.directives)))
        self.directives.append("observe")
        self.observes[id] = evalnode

        evalnode.observed = True
        evalnode.observe_val = obs_val

        val = evalnode.evaluate()
        return val

    def forget(self, id):
        if id in self.observes:
            d = self.observes
        elif id in self.predicts:
            d = self.predicts
        else:
            raise RException("Can only forget predicts and observes")
        evalnode = d[id]
        evalnode.unevaluate()
        # del d[id]
        return

    def report_value(self, id):
        node = self.get_directive_node(id)
        if not node.active:
            raise RException("Error.  Perhaps this directive was forgotten?")
        val = node.val
        return val

    def get_directive_node(self, id):
        if self.directives[id] == "assume":
            node = self.assumes[id]
        elif self.directives[id] == "observe":
            node = self.observes[id]
        elif self.directives[id] == "predict":
            node = self.predicts[id]
        else:
            raise RException("Invalid directive")
        return node

    def add_accepted_proposal(self, hashval):
        if self.mhstats_details:
            if hashval in self.mhstats:
                self.mhstats[hashval]["accepted-proposals"] += 1
            else:
                self.mhstats[hashval] = {}
                self.mhstats[hashval]["accepted-proposals"] = 1
                self.mhstats[hashval]["made-proposals"] = 0
        self.accepted_proposals += 1

    def add_made_proposal(self, hashval):
        if self.mhstats_details:
            if hashval in self.mhstats:
                self.mhstats[hashval]["made-proposals"] += 1
            else:
                self.mhstats[hashval] = {}
                self.mhstats[hashval]["made-proposals"] = 1
                self.mhstats[hashval]["accepted-proposals"] = 0
        self.made_proposals += 1

    def reflip(self, hashval):
        if self.debug:
            print self

        if hashval in self.choices:
            self.application_reflip = True
            self.reflip_node = self.choices[hashval]
            if not self.reflip_node.random_xrp_apply:
                raise RException("Reflipping something which isn't a random xrp application")
            if self.reflip_node.val is None:
                raise RException("Reflipping something which previously had value None")
        else:
            self.application_reflip = False  # internal reflip
            (self.reflip_xrp, nodes) = self.xrps[hashval]
            self.nodes = list_nodes(nodes)

        self.old_to_new_q = 0
        self.new_to_old_q = 0

        old_p = self.p

        self.old_to_new_q = -math.log(self.weight())

        if self.application_reflip:
            self.old_val = self.reflip_node.val
            self.new_val = self.reflip_node.reflip()
        else:
            args_list = []
            self.old_vals = []
            for node in self.nodes:
                if not node.xrp_apply:
                    raise RException("non-XRP application node being used in transition")
                args_list.append(node.args)
                self.old_vals.append(node.val)
            self.old_to_new_q += math.log(self.reflip_xrp.state_weight())
            old_p += self.reflip_xrp.theta_prob()

            self.new_vals, q_forwards, q_back = self.reflip_xrp.theta_mh_prop(args_list, self.old_vals)
            self.old_to_new_q += q_forwards
            self.new_to_old_q += q_back

            for i in range(len(self.nodes)):
                node = self.nodes[i]
                val = self.new_vals[i]
                node.set_val(val)
                node.propagate_up(False, True)

        new_p = self.p
        self.new_to_old_q -= math.log(self.weight())

        if not self.application_reflip:
            new_p += self.reflip_xrp.theta_prob()
            self.new_to_old_q += math.log(self.reflip_xrp.state_weight())

        if self.debug:
            if self.application_reflip:
                print "\nCHANGING VAL OF ", self.reflip_node, "\n  FROM  :  ", self.old_val, "\n  TO   :  ", self.new_val, "\n"
                if (self.old_val.__eq__(self.new_val)).bool:
                    print "SAME VAL"
            else:
                print "TRANSITIONING STATE OF ", self.reflip_xrp

            print "new db", self
            print "\nq(old -> new) : ", math.exp(self.old_to_new_q)
            print "q(new -> old) : ", math.exp(self.new_to_old_q)
            print "p(old) : ", math.exp(old_p)
            print "p(new) : ", math.exp(new_p)
            print "transition prob : ", math.exp(new_p + self.new_to_old_q - old_p - self.old_to_new_q), "\n"

        return old_p, self.old_to_new_q, new_p, self.new_to_old_q

    def restore(self):
        if self.application_reflip:
            self.reflip_node.restore(self.old_val)
        else:
            for i in range(len(self.nodes)):
                node = self.nodes[i]
                node.set_val(self.old_vals[i])
                node.propagate_up(True, True)
            self.reflip_xrp.theta_mh_restore()
        if self.debug:
            print "restore"
            print "\n-----------------------------------------\n"

    def keep(self):
        if self.application_reflip:
            self.reflip_node.restore(self.new_val)  # NOTE: Is necessary for correctness, as we must forget old branch
        else:
            for i in range(len(self.nodes)):
                node = self.nodes[i]
                node.restore(self.new_vals[i])
            self.reflip_xrp.theta_mh_keep()
        if self.debug:
            print "keep"
            print "\n-----------------------------------------\n"

    def add_for_transition(self, xrp, evalnode):
        hashval = xrp.__hash__()
        if hashval not in self.xrps:
            self.xrps[hashval] = (xrp, {})
        (xrp, evalnodes) = self.xrps[hashval]
        evalnodes[evalnode] = True

        weight = xrp.state_weight()
        self.delete_from_db(hashval)
        self.add_to_db(hashval, weight)

    # Add an XRP application node to the db
    def add_xrp(self, xrp, args, val, evalnode, forcing=False):
        self.eval_xrps.append((xrp, args, val))

        evalnode.forcing = forcing
        if not xrp.resample:
            evalnode.random_xrp_apply = True
        prob = xrp.prob(val, args)
        evalnode.setargs(xrp, args, prob)
        xrp.incorporate(val, args)
        xrp.make_link(evalnode, args)
        if not forcing:
            self.old_to_new_q += prob
        self.p += prob
        self.add_for_transition(xrp, evalnode)
        if xrp.resample:
            return

        weight = xrp.weight(args)
        try:
            self.new_to_old_q += math.log(weight)
        except:
            pass  # This is only necessary if we're reflipping

        if (
            self.weighted_db.__contains__(evalnode.hashval)
            or self.db.__contains__(evalnode.hashval)
            or (evalnode.hashval in self.choices)
        ):
            raise RException("DB already had this evalnode")
        self.choices[evalnode.hashval] = evalnode
        self.add_to_db(evalnode.hashval, weight)

    def add_to_db(self, hashval, weight):
        if weight == 0:
            return
        elif weight == 1:
            self.db.__setitem__(hashval, True)
        else:
            self.weighted_db.__setitem__(hashval, True, weight)

    def remove_for_transition(self, xrp, evalnode):
        hashval = xrp.__hash__()
        (xrp, evalnodes) = self.xrps[hashval]
        del evalnodes[evalnode]
        if len(evalnodes) == 0:
            del self.xrps[hashval]
            self.delete_from_db(hashval)

    def remove_xrp(self, xrp, args, val, evalnode):
        self.uneval_xrps.append((xrp, args, val))

        xrp.break_link(evalnode, args)
        xrp.remove(val, args)
        prob = xrp.prob(val, args)
        if not evalnode.forcing:
            self.new_to_old_q += prob
        self.p -= prob
        self.remove_for_transition(xrp, evalnode)
        if xrp.resample:
            return

        xrp = evalnode.xrp
        try:  # TODO: dont do this here.. dont do cancellign
            self.old_to_new_q += math.log(xrp.weight(evalnode.args))
        except:
            pass  # This fails when restoring/keeping, for example

        if evalnode.hashval not in self.choices:
            raise RException("Choices did not already have this evalnode")
        else:
            del self.choices[evalnode.hashval]
            self.delete_from_db(evalnode.hashval)

    def delete_from_db(self, hashval):
        if self.db.__contains__(hashval):
            self.db.__delitem__(hashval)
        elif self.weighted_db.__contains__(hashval):
            self.weighted_db.__delitem__(hashval)

    def randomKey(self):
        if rrandom.random.random() * self.weight() > self.db.weight():
            return self.weighted_db.randomKey()
        else:
            return self.db.randomKey()

    def infer(self):
        try:
            hashval = self.randomKey()
        except:
            raise RException("Program has no randomness!")
        old_p, old_to_new_q, new_p, new_to_old_q = self.reflip(hashval)
        p = rrandom.random.random()
        if new_p + self.new_to_old_q - old_p - self.old_to_new_q < math.log(p):
            self.restore()
        else:
            self.keep()
            self.add_accepted_proposal(hashval)
        self.add_made_proposal(hashval)

    def reset(self):
        self.__init__()

    def __str__(self):
        string = "EvalNodeTree:"
        for evalnode in self.assumes.values():
            string += evalnode.str_helper()
        for evalnode in self.observes.values():
            string += evalnode.str_helper()
        for evalnode in self.predicts.values():
            string += evalnode.str_helper()
        return string
Ejemplo n.º 5
0
class RandomDB(Engine):
  def __init__(self):
    #self.db = {} 
    self.engine_type = 'randomdb'
    self.db = RandomChoiceDict() 
    self.db_noise = {}
    self.log = []
    # ALWAYS WORKING WITH LOG PROBABILITIES
    self.uneval_p = 0
    self.eval_p = 0
    self.p = 0 

    self.env = Environment()

    self.assumes = {}
    self.observes = {}
    self.predicts = {}
    self.vars = {}

  def reset(self):
    self.__init__()

  def assume(self, varname, expr, id): 
    self.assumes[id] = (varname, expr)
    self.vars[varname] = expr
    value = self.evaluate(expr, self.env, reflip = True, stack = [id])
    self.env.set(varname, value)
    return value

  def observe(self, expr, obs_val, id):
    if expr.hashval in self.observes:
      raise RException('Already observed %s' % str(expr))
    self.observes[id] = (expr, obs_val) 
    # bit of a hack, here, to make it recognize same things as with noisy_expr
    self.evaluate(expr, self.env, reflip = False, stack = [id], xrp_force_val = obs_val)
    return expr.hashval

  def predict(self, expr, id):
    self.predicts[id] = expr
    return self.evaluate(expr, self.env, True, [id])

  def forget(self, id):
    self.remove(['obs', id])
    assert id in self.observes
    del self.observes[id]
    return

  def insert(self, stack, xrp, value, args, is_obs_noise = False, memorize = True):
    stack = tuple(stack)
    if self.has(stack):
      self.remove(stack)
    prob = xrp.prob(value, args)
    self.p += prob
    xrp.incorporate(value, args)
    if is_obs_noise:
      self.db_noise[stack] = (xrp, value, args, True)
    else:
      self.db[stack] = (xrp, value, args, False)
    if not is_obs_noise:
      self.eval_p += prob # hmmm.. 
    if memorize:
      self.log.append(('insert', stack, xrp, value, args, is_obs_noise))

  def remove(self, stack, memorize = True):
    stack = tuple(stack)
    assert self.has(stack)
    (xrp, value, args, is_obs_noise) = self.get(stack)
    xrp.remove(value, args)
    prob = xrp.prob(value, args)
    self.p -= prob
    if is_obs_noise:
      del self.db_noise[stack]
    else:
      del self.db[stack]
      self.uneval_p += prob # previously unindented...
    if memorize:
      self.log.append(('remove', stack, xrp, value, args, is_obs_noise))

  def has(self, stack):
    stack = tuple(stack)
    return ((stack in self.db) or (stack in self.db_noise)) 

  def get(self, stack):
    stack = tuple(stack)
    if stack in self.db:
      return self.db[stack]
    elif stack in self.db_noise:
      return self.db_noise[stack]
    else:
      raise RException('Failed to get stack %s' % str(stack))

  def infer(self):
    try:
      stack = self.random_stack()
    except:
      return
    self.reflip(stack)

  def random_stack(self):
    key = self.db.randomKey()
    return key

  def evaluate_recurse(self, subexpr, env, reflip, stack, addition):
    newstack = [s for s in stack].extend(addition)
    val = self.evaluate(subexpr, env, reflip, newstack)
    return val
      
  def binary_op_evaluate(self, expr, env, reflip, stack):
    val1 = self.evaluate_recurse(expr.children[0], env, reflip, stack, ['operand0'])
    val2 = self.evaluate_recurse(expr.children[1], env, reflip, stack, ['operand1'])
    return (val1 , val2)
  
  def children_evaluate(self, expr, env, reflip, stack):
    return [self.evaluate_recurse(expr.children[i], env, reflip, stack, ['child' + str(i)]) for i in range(len(expr.children))]

  # Draws a sample value (without re-sampling other values) given its parents, and sets it
  def evaluate(self, expr, env = None, reflip = False, stack = [], xrp_force_val = None):
    if env is None:
      env = self.env
      
    if xrp_force_val is not None: 
      assert expr.type == 'apply'
      
    if expr.type == 'value':
      val = expr.val
    elif expr.type == 'variable':
      var = expr.name
      (val, lookup_env) = env.lookup(var)
    elif expr.type == 'if':
      cond = self.evaluate_recurse(expr.cond, env, reflip, stack , ['cond'])
      if cond.bool: 
        self.unevaluate(stack + ['false'])
        val = self.evaluate_recurse(expr.true, env, reflip, stack , ['true'])
      else:
        self.unevaluate(stack + ['true'])
        val = self.evaluate_recurse(expr.false, env, reflip, stack , ['false'])
    #elif expr.type == 'switch':
    #  index = self.evaluate_recurse(expr.index, env, reflip, stack , ['index'])
    #  assert 0 <= index.num < expr.n
    #  for i in range(expr.n):
    #    if i != index.num:
    #      self.unevaluate(stack + ['child' + str(i)])
    #  val = self.evaluate_recurse(expr.children[index.num], env, reflip, stack, ['child' + str(index.num)])
    elif expr.type == 'let':
      # NOTE: this really is a let*

      n = len(expr.vars)
      assert len(expr.expressions) == n
      values = []
      new_env = env
      for i in range(n): # Bind variables
        new_env = new_env.spawn_child()
        val = self.evaluate_recurse(expr.expressions[i], new_env, reflip, stack, ['let' + str(i)])
        values.append(val)
        new_env.set(expr.vars[i], values[i])
        if val.type == 'procedure':
          val.env = new_env
      new_body = expr.body.replace(new_env)
      val = self.evaluate_recurse(new_body, new_env, reflip, stack, ['body'])
    elif expr.type == 'apply':
      n = len(expr.children)
      args = [self.evaluate_recurse(expr.children[i], env, reflip, stack, ['arg' + str(i)]) for i in range(n)]
      op = self.evaluate_recurse(expr.op, env, reflip, stack , ['operator'])

      addition = ','.join([x.str_hash for x in args])
      if op.type == 'procedure':
        self.unevaluate(stack + ['apply'], addition)
        if n != len(op.vars):
          raise RException('Procedure should have %d arguments.  \nVars were \n%s\n, but children were \n%s.' % (n, op.vars, expr.children))
        new_env = op.env.spawn_child()
        for i in range(n):
          new_env.set(op.vars[i], args[i])
        val = self.evaluate_recurse(op.body, new_env, reflip, stack, ['apply', addition])
      elif op.type == 'xrp':
        self.unevaluate(stack + ['apply'], addition)
  
        if xrp_force_val is not None:
          assert not reflip
          if self.has(stack):
            self.remove(stack)
          self.insert(stack, op.xrp, xrp_force_val, args, True)
          val = xrp_force_val
        else:
          substack = stack + ['apply', addition]
          if not self.has(substack):
            if op.xrp.is_mem_proc():
              val = op.xrp.apply_mem(args, self, stack)
            else:
              val = op.xrp.apply(args)
            self.insert(substack, op.xrp, val, args)
          else:
            if reflip:
              self.remove(substack)
              if op.xrp.is_mem_proc():
                val = op.xrp.apply_mem(args, self, stack)
              else:
                val = op.xrp.apply(args)
              self.insert(substack, op.xrp, val, args)
            else:
              (xrp, val, dbargs, is_obs_noise) = self.get(substack)
              assert not is_obs_noise
      else:
        raise RException('Must apply either a procedure or xrp.  Instead got expression %s' % str(op))
    elif expr.type == 'function':
      n = len(expr.vars)
      new_env = env.spawn_child()
      bound = {}
      for i in range(n): # Bind variables
        bound[expr.vars[i]] = True
      procedure_body = expr.body.replace(new_env, bound)
      val = Procedure(expr.vars, procedure_body, env)
    elif expr.type == '=':
      (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
      val = val1.__eq__(val2)
    elif expr.type == '<':
      (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
      val = val1.__lt__(val2)
    elif expr.type == '>':
      (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
      val = val1.__gt__(val2)
    elif expr.type == '<=':
      (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
      val = val1.__le__(val2)
    elif expr.type == '>=':
      (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
      val = val1.__ge__(val2)
    elif expr.type == '&':
      vals = self.children_evaluate(expr, env, reflip, stack)
      andval = BoolValue(True)
      for x in vals:
        andval = andval.__and__(x.bool)
      val = andval
    elif expr.type == '^':
      vals = self.children_evaluate(expr, env, reflip, stack)
      xorval = BoolValue(True)
      for x in vals:
        xorval = xorval.__xor__(x.bool)
      val = xorval
    elif expr.type == '|':
      vals = self.children_evaluate(expr, env, reflip, stack)
      orval = BoolValue(False)
      for x in vals:
        orval = orval.__or__(x.bool)
      val = orval
    elif expr.type == '~':
      negval = self.evaluate_recurse(expr.children[0], env, reflip, stack, ['neg'])
      val = negval.__inv__()
    elif expr.type == '+':
      vals = self.children_evaluate(expr, env, reflip, stack)
      sum_val = NatValue(0)
      for x in vals:
        sum_val = sum_val.__add__(x)
      val = sum_val
    elif expr.type == '-':
      val1 = self.evaluate_recurse(expr.children[0], env, reflip, stack , ['sub0'])
      val2 = self.evaluate_recurse(expr.children[1], env, reflip, stack , ['sub1'])
      val = val1.__sub__(val2)
    elif expr.type == '*':
      vals = self.children_evaluate(expr, env, reflip, stack)
      prod_val = NatValue(1)
      for x in vals:
        prod_val = prod_val.__mul__(x)
      val = prod_val
    elif expr.type == '/':
      val1 = self.evaluate_recurse(expr.children[0], env, reflip, stack , ['div0'])
      val2 = self.evaluate_recurse(expr.children[1], env, reflip, stack , ['div1'])
      val = val1.__div__(val2)
    else:
      raise RException('Invalid expression type %s' % expr.type)
    return val


  def unevaluate(self, uneval_stack, args = None):
    if args is not None:
      args = tuple(args)

    to_unevaluate = []

    for tuple_stack in self.db:
      to_unevaluate.append(tuple_stack)
    for tuple_stack in self.db_noise:
      to_unevaluate.append(tuple_stack)

    to_delete = []

    for tuple_stack in to_unevaluate:
      stack = list(tuple_stack) 
      if len(stack) >= len(uneval_stack) and stack[:len(uneval_stack)] == uneval_stack:
        if args is None:
          to_delete.append(tuple_stack)
        else:
          assert len(stack) > len(uneval_stack)
          if stack[len(uneval_stack)] != args:
            to_delete.append(tuple_stack)

    for tuple_stack in to_delete:
      self.remove(tuple_stack)

  def save(self):
    self.log = []
    self.uneval_p = 0
    self.eval_p = 0

  def restore(self):
    self.log.reverse()
    for (type, stack, xrp, value, args, is_obs_noise) in self.log:
      if type == 'insert':
        self.remove(stack, False)
      else:
        assert type == 'remove'
        self.insert(stack, xrp, value, args, is_obs_noise, False)

  def reflip(self, stack):
    (xrp, val, args, is_obs_noise) = self.get(stack)
  
    #debug = True 
    debug = False
  
    old_p = self.p
    old_to_new_q = - math.log(len(self.db))
    if debug:
      print  "old_db", self 
  
    self.save()
  
    self.remove(stack)
    if xrp.is_mem_proc():
      new_val = xrp.apply(args, list(stack))
    else:
      new_val = xrp.apply(args)
    self.insert(stack, xrp, new_val, args)
  
    if debug:
      print "\nCHANGING ", stack, "\n  TO   :  ", new_val, "\n"
  
    if val == new_val:
      return
  
    self.rerun(False)
    new_p = self.p
    new_to_old_q = self.uneval_p - math.log(len(self.db))
    old_to_new_q += self.eval_p
    if debug:
      print "new db", self, \
            "\nq(old -> new) : ", old_to_new_q, \
            "q(new -> old) : ", new_to_old_q, \
            "p(old) : ", old_p, \
            "p(new) : ", new_p, \
            'log transition prob : ',  new_p + new_to_old_q - old_p - old_to_new_q , "\n"
  
    if old_p * old_to_new_q > 0:
      p = rrandom.random.random()
      if new_p + new_to_old_q - old_p - old_to_new_q < math.log(p):
        self.restore()
        if debug:
          print 'restore'
        self.rerun(False)
  
    if debug:
      print "new db", self
      print "\n-----------------------------------------\n"

  def __str__(self):
    string = 'DB with state:'
    string += '\n  Regular Flips:'
    for s in self.db:
      string += '\n    %s <- %s' % (self.db[s][1].val, s) 
    string += '\n  Observe Flips:'
    for s in self.db_noise:
      string += '\n    %s <- %s' % (self.db_noise[s][1].val, s) 
    return string

  def __contains__(self, stack):
    return self.has(self, stack)

  def __getitem__(self, stack):
    return self.get(self, stack)