def __init__(self):
  """Initialize a 'reduced traces' engine instance.

  Sets up directive bookkeeping, the random-choice databases, the global
  environment, log-probability accumulators for MH inference, and
  placeholder state that reflip()/restore()/keep() overwrite.

  NOTE(review): original formatting was collapsed onto one line;
  indentation here is reconstructed.
  """
  self.engine_type = 'reduced traces'
  self.assumes = {}    # id -> evalnode
  self.observes = {}   # id -> evalnode
  self.predicts = {}   # id -> evalnode
  self.directives = []
  # Databases of random choices, keyed by evalnode hash:
  # unit-weight choices go in db, others in weighted_db (see add_to_db).
  self.db = RandomChoiceDict()
  self.weighted_db = WeightedRandomChoiceDict()
  self.choices = {}  # hash -> evalnode
  self.xrps = {}     # hash -> (xrp, set of application nodes)
  self.env = EnvironmentNode()
  # Log-probability accumulators used to form the MH acceptance ratio.
  self.p = 0
  self.uneval_p = 0
  self.eval_p = 0
  self.new_to_old_q = 0
  self.old_to_new_q = 0
  self.debug = False
  # necessary because of the new XRP interface requiring some state kept
  # while doing inference -- these are placeholders overwritten by reflip().
  self.application_reflip = False
  self.reflip_node = ReducedEvalNode(self, self.env, VarExpression(''))
  self.nodes = []
  self.old_vals = [Value()]
  self.new_vals = [Value()]
  self.old_val = Value()
  self.new_val = Value()
  self.reflip_xrp = XRP()
  # Proposal statistics (optionally per-choice when mhstats_details is on).
  self.mhstats_details = False
  self.mhstats = {}
  self.made_proposals = 0
  self.accepted_proposals = 0
  # Random identifier for this engine instance.
  self.hashval = rrandom.random.randbelow()
  return
def __init__(self):
  """Initialize a full 'traces' engine instance.

  Like the reduced-traces constructor but additionally records the XRP
  applications evaluated/unevaluated during a proposal (eval_xrps /
  uneval_xrps) and uses EvalNode rather than ReducedEvalNode.

  NOTE(review): original formatting was collapsed onto one line;
  indentation here is reconstructed.
  """
  self.engine_type = "traces"
  self.assumes = {}    # id -> evalnode
  self.observes = {}   # id -> evalnode
  self.predicts = {}   # id -> evalnode
  self.directives = []
  # Databases of random choices, keyed by evalnode hash.
  self.db = RandomChoiceDict()
  self.weighted_db = WeightedRandomChoiceDict()
  self.choices = {}  # hash -> evalnode
  self.xrps = {}     # hash -> (xrp, set of application nodes)
  self.env = EnvironmentNode()
  # Log-probability accumulators used to form the MH acceptance ratio.
  self.p = 0
  self.old_to_new_q = 0
  self.new_to_old_q = 0
  self.eval_xrps = []    # (xrp, args, val)
  self.uneval_xrps = []  # (xrp, args, val)
  self.debug = False
  # Stuff for restoring -- placeholders overwritten by reflip().
  self.application_reflip = False
  self.reflip_node = EvalNode(self, self.env, VarExpression(""))
  self.nodes = []
  self.old_vals = []
  self.new_vals = []
  self.old_val = Value()
  self.new_val = Value()
  self.reflip_xrp = XRP()
  # Proposal statistics (optionally per-choice when mhstats_details is on).
  self.made_proposals = 0
  self.accepted_proposals = 0
  self.mhstats_details = False
  self.mhstats = {}
  return
def __init__(self):
  """Initialize a 'reduced traces' engine instance.

  Duplicate of the constructor fragment above -- NOTE(review): this file
  appears to contain repeated copies of the same definitions; confirm
  which copy is live before editing behavior.
  """
  self.engine_type = 'reduced traces'
  self.assumes = {}    # id -> evalnode
  self.observes = {}   # id -> evalnode
  self.predicts = {}   # id -> evalnode
  self.directives = []
  # Databases of random choices, keyed by evalnode hash.
  self.db = RandomChoiceDict()
  self.weighted_db = WeightedRandomChoiceDict()
  self.choices = {}  # hash -> evalnode
  self.xrps = {}     # hash -> (xrp, set of application nodes)
  self.env = EnvironmentNode()
  # Log-probability accumulators used to form the MH acceptance ratio.
  self.p = 0
  self.uneval_p = 0
  self.eval_p = 0
  self.new_to_old_q = 0
  self.old_to_new_q = 0
  self.debug = False
  # necessary because of the new XRP interface requiring some state kept
  # while doing inference -- placeholders overwritten by reflip().
  self.application_reflip = False
  self.reflip_node = ReducedEvalNode(self, self.env, VarExpression(''))
  self.nodes = []
  self.old_vals = [Value()]
  self.new_vals = [Value()]
  self.old_val = Value()
  self.new_val = Value()
  self.reflip_xrp = XRP()
  # Proposal statistics.
  self.mhstats_details = False
  self.mhstats = {}
  self.made_proposals = 0
  self.accepted_proposals = 0
  self.hashval = rrandom.random.randbelow()
  return
class ReducedTraces(Engine):
  """Metropolis-Hastings engine over a reduced trace of evaluation nodes.

  Random XRP applications are indexed in two databases keyed by evalnode
  hash (unit-weight choices in `db`, others in `weighted_db`).  infer()
  draws one choice, reproposes it via reflip(), and accepts or rejects
  the move using the log-probabilities accumulated in p / eval_p /
  uneval_p / old_to_new_q / new_to_old_q.

  NOTE(review): the original file had statements collapsed onto single
  lines; all indentation below is reconstructed and should be confirmed
  against the upstream source.
  """

  def __init__(self):
    """Set up directive tables, choice databases, and MH scratch state."""
    self.engine_type = 'reduced traces'
    self.assumes = {}    # id -> evalnode
    self.observes = {}   # id -> evalnode
    self.predicts = {}   # id -> evalnode
    self.directives = []
    self.db = RandomChoiceDict()
    self.weighted_db = WeightedRandomChoiceDict()
    self.choices = {}  # hash -> evalnode
    self.xrps = {}     # hash -> (xrp, set of application nodes)
    self.env = EnvironmentNode()
    # Log-probability accumulators for the MH acceptance ratio.
    self.p = 0
    self.uneval_p = 0
    self.eval_p = 0
    self.new_to_old_q = 0
    self.old_to_new_q = 0
    self.debug = False
    # necessary because of the new XRP interface requiring some state kept
    # while doing inference -- placeholders overwritten by reflip().
    self.application_reflip = False
    self.reflip_node = ReducedEvalNode(self, self.env, VarExpression(''))
    self.nodes = []
    self.old_vals = [Value()]
    self.new_vals = [Value()]
    self.old_val = Value()
    self.new_val = Value()
    self.reflip_xrp = XRP()
    # Proposal statistics.
    self.mhstats_details = False
    self.mhstats = {}
    self.made_proposals = 0
    self.accepted_proposals = 0
    self.hashval = rrandom.random.randbelow()
    return

  def mhstats_on(self):
    """Enable per-choice proposal statistics collection."""
    self.mhstats_details = True

  def mhstats_off(self):
    """Disable per-choice proposal statistics collection."""
    self.mhstats_details = False

  def mhstats_aggregated(self):
    """Return totals of made/accepted proposals as a dict."""
    d = {}
    d['made-proposals'] = self.made_proposals
    d['accepted-proposals'] = self.accepted_proposals
    return d

  def mhstats_detailed(self):
    """Return the per-choice proposal statistics dict."""
    return self.mhstats

  def get_log_score(self, id):
    """Return the trace log-probability, or an observe's log-prob for id >= 0."""
    if id == -1:
      return self.p
    else:
      return self.observes[id].p

  def weight(self):
    """Total proposal weight over both choice databases."""
    return self.db.weight() + self.weighted_db.weight()

  def random_choices(self):
    """Number of random choices currently registered."""
    return self.db.__len__() + self.weighted_db.__len__()

  def assume(self, name, expr, id=-1):
    """Bind `name` to the value of `expr` in the global environment.

    id == -1 means an internal assume that is not recorded as a directive.
    """
    evalnode = ReducedEvalNode(self, self.env, expr)
    self.env.add_assume(name, evalnode)
    evalnode.add_assume(name, self.env)
    if id != -1:
      self.assumes[id] = evalnode
      assert id == len(self.directives)
      self.directives.append('assume')
    val = evalnode.evaluate()
    return val

  def predict(self, expr, id):
    """Evaluate `expr` and record it as a predict directive."""
    evalnode = ReducedEvalNode(self, self.env, expr)
    assert id == len(self.directives)
    self.directives.append('predict')
    self.predicts[id] = evalnode
    evalnode.predict = True
    val = evalnode.evaluate()
    return val

  def observe(self, expr, obs_val, id):
    """Evaluate `expr` constrained to the observed value `obs_val`."""
    evalnode = ReducedEvalNode(self, self.env, expr)
    assert id == len(self.directives)
    self.directives.append('observe')
    self.observes[id] = evalnode
    evalnode.observed = True
    evalnode.observe_val = obs_val
    val = evalnode.evaluate()
    return val

  def forget(self, id):
    """Unevaluate the observe/predict with this id (assumes cannot be forgotten)."""
    if id in self.observes:
      d = self.observes
    elif id in self.predicts:
      d = self.predicts
    else:
      raise RException("Can only forget predicts and observes")
    evalnode = d[id]
    evalnode.unevaluate()
    #del d[id]
    return

  def report_value(self, id):
    """Return the current value of directive `id`; error if it was forgotten."""
    node = self.get_directive_node(id)
    if not node.active:
      raise RException("Error.  Perhaps this directive was forgotten?")
    val = node.val
    return val

  def get_directive_node(self, id):
    """Look up the evalnode for directive `id` by its recorded kind."""
    if self.directives[id] == 'assume':
      node = self.assumes[id]
    elif self.directives[id] == 'observe':
      node = self.observes[id]
    else:
      assert self.directives[id] == 'predict'
      node = self.predicts[id]
    return node

  def add_accepted_proposal(self, hashval):
    """Count an accepted proposal (and per-choice, when details are on)."""
    if self.mhstats_details:
      if hashval in self.mhstats:
        self.mhstats[hashval]['accepted-proposals'] += 1
      else:
        self.mhstats[hashval] = {}
        self.mhstats[hashval]['accepted-proposals'] = 1
        self.mhstats[hashval]['made-proposals'] = 0
    self.accepted_proposals += 1

  def add_made_proposal(self, hashval):
    """Count a made proposal (and per-choice, when details are on)."""
    if self.mhstats_details:
      if hashval in self.mhstats:
        self.mhstats[hashval]['made-proposals'] += 1
      else:
        self.mhstats[hashval] = {}
        self.mhstats[hashval]['made-proposals'] = 1
        self.mhstats[hashval]['accepted-proposals'] = 0
    self.made_proposals += 1

  def reflip(self, hashval):
    """Propose a change to the choice identified by `hashval`.

    If hashval is in self.choices this is an application reflip (resample
    one random XRP application); otherwise it is an internal state
    transition of the XRP registered under hashval.

    Returns (old_p, old_to_new_q, new_p, new_to_old_q) -- the log
    probabilities and proposal log-densities for the MH ratio.
    """
    if self.debug:
      print self

    if hashval in self.choices:
      self.application_reflip = True
      self.reflip_node = self.choices[hashval]
      if not self.reflip_node.random_xrp_apply:
        raise RException("Reflipping something which isn't a random xrp application")
      if self.reflip_node.val is None:
        raise RException("Reflipping something which previously had value None")
    else:
      self.application_reflip = False  # internal reflip
      (self.reflip_xrp, nodes) = self.xrps[hashval]
      self.nodes = list_nodes(nodes)

    # Fresh accumulators for log-probs of evaluation/unevaluation done
    # during this proposal (filled in by add_xrp/remove_xrp callbacks).
    self.eval_p = 0
    self.uneval_p = 0

    old_p = self.p
    # q(old->new) starts with the probability of picking this choice,
    # i.e. -log(total weight).
    self.old_to_new_q = -math.log(self.weight())
    if self.application_reflip:
      self.old_val = self.reflip_node.val
      self.new_val = self.reflip_node.reflip()
    else:
      # TODO: this is copied from traces.  is it correct?
      args_list = []
      self.old_vals = []
      for node in self.nodes:
        args_list.append(node.args)
        self.old_vals.append(node.val)
      self.old_to_new_q += math.log(self.reflip_xrp.state_weight())
      old_p += self.reflip_xrp.theta_prob()
      self.new_vals, q_forwards, q_back = self.reflip_xrp.theta_mh_prop(args_list, self.old_vals)
      self.old_to_new_q += q_forwards
      self.new_to_old_q += q_back
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        val = self.new_vals[i]
        node.set_val(val)
        node.propagate_up(False)

    new_p = self.p
    # NOTE(review): this assignment overwrites the q_back added to
    # new_to_old_q above; the Traces engine uses `-=` here instead.
    # Possibly a bug -- confirm against upstream.
    self.new_to_old_q = -math.log(self.weight())
    self.old_to_new_q += self.eval_p
    self.new_to_old_q += self.uneval_p
    if not self.application_reflip:
      new_p += self.reflip_xrp.theta_prob()
      self.new_to_old_q += math.log(self.reflip_xrp.state_weight())

    if self.debug:
      if self.application_reflip:
        print "\nCHANGING VAL OF ", self.reflip_node, "\n  FROM  : ", self.old_val, "\n  TO   : ", self.new_val, "\n"
        if (self.old_val.__eq__(self.new_val)).bool:
          print "SAME VAL"
      else:
        print "TRANSITIONING STATE OF ", self.reflip_xrp
      print "new db", self
      print "\nq(old -> new) : ", math.exp(self.old_to_new_q)
      print "q(new -> old) : ", math.exp(self.new_to_old_q)
      print "p(old) : ", math.exp(old_p)
      print "p(new) : ", math.exp(new_p)
      print 'transition prob : ', math.exp(new_p + self.new_to_old_q - old_p - self.old_to_new_q), "\n"
      print "\n-----------------------------------------\n"

    return old_p, self.old_to_new_q, new_p, self.new_to_old_q

  def restore(self):
    """Undo the last proposal (rejected move)."""
    if self.application_reflip:
      self.reflip_node.restore(self.old_val)
    else:
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        node.set_val(self.old_vals[i])
        node.propagate_up(True)
      self.reflip_xrp.theta_mh_restore()
    if self.debug:
      print 'restore'

  def keep(self):
    """Commit the last proposal (accepted move)."""
    if self.application_reflip:
      self.reflip_node.restore(self.new_val)  # NOTE: Is necessary for correctness, as we must forget old branch
    else:
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        node.restore(self.new_vals[i])
      self.reflip_xrp.theta_mh_keep()

  def add_for_transition(self, xrp, evalnode):
    """Register `evalnode` as an application of `xrp` for internal transitions."""
    hashval = xrp.__hash__()
    if hashval not in self.xrps:
      self.xrps[hashval] = (xrp, {})
    (xrp, evalnodes) = self.xrps[hashval]
    evalnodes[evalnode] = True
    # Re-insert under the XRP's (possibly changed) state weight.
    weight = xrp.state_weight()
    self.delete_from_db(xrp.__hash__())
    self.add_to_db(xrp.__hash__(), weight)

  # Add an XRP application node to the db
  def add_xrp(self, xrp, args, evalnode):
    """Register a random XRP application as a reflippable choice."""
    weight = xrp.weight(args)
    evalnode.setargs(args)
    try:
      # log(0) raises -- weight-0 choices contribute nothing to q.
      self.new_to_old_q += math.log(weight)
    except:
      pass  # This is only necessary if we're reflipping
    if self.weighted_db.__contains__(evalnode.hashval) or self.db.__contains__(evalnode.hashval) or (evalnode.hashval in self.choices):
      raise RException("DB already had this evalnode")
    self.choices[evalnode.hashval] = evalnode
    self.add_to_db(evalnode.hashval, weight)

  def add_to_db(self, hashval, weight):
    """Insert a choice: weight 0 is skipped, weight 1 goes to the uniform db."""
    if weight == 0:
      return
    elif weight == 1:
      self.db.__setitem__(hashval, True)
    else:
      self.weighted_db.__setitem__(hashval, True, weight)

  def remove_for_transition(self, xrp, evalnode):
    """Unregister an application node; drop the XRP when none remain."""
    hashval = xrp.__hash__()
    (xrp, evalnodes) = self.xrps[hashval]
    del evalnodes[evalnode]
    if len(evalnodes) == 0:
      del self.xrps[hashval]
      self.delete_from_db(xrp.__hash__())

  def remove_xrp(self, evalnode):
    """Remove a random XRP application from the choice databases."""
    xrp = evalnode.xrp
    try:
      self.old_to_new_q += math.log(xrp.weight(evalnode.args))
    except:
      pass  # This fails when restoring/keeping, for example
    if evalnode.hashval not in self.choices:
      raise RException("Choices did not already have this evalnode")
    else:
      del self.choices[evalnode.hashval]
    self.delete_from_db(evalnode.hashval)

  def delete_from_db(self, hashval):
    """Delete hashval from whichever database holds it (no-op if absent)."""
    if self.db.__contains__(hashval):
      self.db.__delitem__(hashval)
    elif self.weighted_db.__contains__(hashval):
      self.weighted_db.__delitem__(hashval)

  def randomKey(self):
    """Sample a choice hash, weighted across the two databases."""
    if rrandom.random.random() * self.weight() > self.db.weight():
      return self.weighted_db.randomKey()
    else:
      return self.db.randomKey()

  def infer(self):
    """Perform one MH step: propose via reflip(), then accept or reject."""
    try:
      hashval = self.randomKey()
    except:
      raise RException("Program has no randomness!")
    old_p, old_to_new_q, new_p, new_to_old_q = self.reflip(hashval)
    p = rrandom.random.random()
    # Accept iff log(p) <= new_p + q(new->old) - old_p - q(old->new).
    if new_p + new_to_old_q - old_p - old_to_new_q < math.log(p):
      self.restore()
    else:
      self.keep()
      self.add_accepted_proposal(hashval)
    self.add_made_proposal(hashval)

  def reset(self):
    """Re-initialize the engine, discarding all directives and state."""
    self.__init__()

  def __str__(self):
    string = "EvalNodeTree:"
    for evalnode in self.assumes.values():
      string += evalnode.str_helper()
    for evalnode in self.observes.values():
      string += evalnode.str_helper()
    for evalnode in self.predicts.values():
      string += evalnode.str_helper()
    return string
class ReducedTraces(Engine):
  """Metropolis-Hastings engine over a reduced trace of evaluation nodes.

  NOTE(review): this class is a near-byte duplicate of the ReducedTraces
  definition earlier in the file (only incidental spacing differed);
  confirm which copy is live and remove the redundancy upstream.

  Random XRP applications are indexed in two databases keyed by evalnode
  hash (unit-weight choices in `db`, others in `weighted_db`).  infer()
  draws one choice, reproposes it via reflip(), and accepts or rejects
  the move using the accumulated log-probabilities.
  """

  def __init__(self):
    """Set up directive tables, choice databases, and MH scratch state."""
    self.engine_type = 'reduced traces'
    self.assumes = {}    # id -> evalnode
    self.observes = {}   # id -> evalnode
    self.predicts = {}   # id -> evalnode
    self.directives = []
    self.db = RandomChoiceDict()
    self.weighted_db = WeightedRandomChoiceDict()
    self.choices = {}  # hash -> evalnode
    self.xrps = {}     # hash -> (xrp, set of application nodes)
    self.env = EnvironmentNode()
    # Log-probability accumulators for the MH acceptance ratio.
    self.p = 0
    self.uneval_p = 0
    self.eval_p = 0
    self.new_to_old_q = 0
    self.old_to_new_q = 0
    self.debug = False
    # necessary because of the new XRP interface requiring some state kept
    # while doing inference -- placeholders overwritten by reflip().
    self.application_reflip = False
    self.reflip_node = ReducedEvalNode(self, self.env, VarExpression(''))
    self.nodes = []
    self.old_vals = [Value()]
    self.new_vals = [Value()]
    self.old_val = Value()
    self.new_val = Value()
    self.reflip_xrp = XRP()
    # Proposal statistics.
    self.mhstats_details = False
    self.mhstats = {}
    self.made_proposals = 0
    self.accepted_proposals = 0
    self.hashval = rrandom.random.randbelow()
    return

  def mhstats_on(self):
    """Enable per-choice proposal statistics collection."""
    self.mhstats_details = True

  def mhstats_off(self):
    """Disable per-choice proposal statistics collection."""
    self.mhstats_details = False

  def mhstats_aggregated(self):
    """Return totals of made/accepted proposals as a dict."""
    d = {}
    d['made-proposals'] = self.made_proposals
    d['accepted-proposals'] = self.accepted_proposals
    return d

  def mhstats_detailed(self):
    """Return the per-choice proposal statistics dict."""
    return self.mhstats

  def get_log_score(self, id):
    """Return the trace log-probability, or an observe's log-prob for id >= 0."""
    if id == -1:
      return self.p
    else:
      return self.observes[id].p

  def weight(self):
    """Total proposal weight over both choice databases."""
    return self.db.weight() + self.weighted_db.weight()

  def random_choices(self):
    """Number of random choices currently registered."""
    return self.db.__len__() + self.weighted_db.__len__()

  def assume(self, name, expr, id = -1):
    """Bind `name` to the value of `expr` in the global environment.

    id == -1 means an internal assume that is not recorded as a directive.
    """
    evalnode = ReducedEvalNode(self, self.env, expr)
    self.env.add_assume(name, evalnode)
    evalnode.add_assume(name, self.env)
    if id != -1:
      self.assumes[id] = evalnode
      assert id == len(self.directives)
      self.directives.append('assume')
    val = evalnode.evaluate()
    return val

  def predict(self, expr, id):
    """Evaluate `expr` and record it as a predict directive."""
    evalnode = ReducedEvalNode(self, self.env, expr)
    assert id == len(self.directives)
    self.directives.append('predict')
    self.predicts[id] = evalnode
    evalnode.predict = True
    val = evalnode.evaluate()
    return val

  def observe(self, expr, obs_val, id):
    """Evaluate `expr` constrained to the observed value `obs_val`."""
    evalnode = ReducedEvalNode(self, self.env, expr)
    assert id == len(self.directives)
    self.directives.append('observe')
    self.observes[id] = evalnode
    evalnode.observed = True
    evalnode.observe_val = obs_val
    val = evalnode.evaluate()
    return val

  def forget(self, id):
    """Unevaluate the observe/predict with this id (assumes cannot be forgotten)."""
    if id in self.observes:
      d = self.observes
    elif id in self.predicts:
      d = self.predicts
    else:
      raise RException("Can only forget predicts and observes")
    evalnode = d[id]
    evalnode.unevaluate()
    #del d[id]
    return

  def report_value(self, id):
    """Return the current value of directive `id`; error if it was forgotten."""
    node = self.get_directive_node(id)
    if not node.active:
      raise RException("Error.  Perhaps this directive was forgotten?")
    val = node.val
    return val

  def get_directive_node(self, id):
    """Look up the evalnode for directive `id` by its recorded kind."""
    if self.directives[id] == 'assume':
      node = self.assumes[id]
    elif self.directives[id] == 'observe':
      node = self.observes[id]
    else:
      assert self.directives[id] == 'predict'
      node = self.predicts[id]
    return node

  def add_accepted_proposal(self, hashval):
    """Count an accepted proposal (and per-choice, when details are on)."""
    if self.mhstats_details:
      if hashval in self.mhstats:
        self.mhstats[hashval]['accepted-proposals'] += 1
      else:
        self.mhstats[hashval] = {}
        self.mhstats[hashval]['accepted-proposals'] = 1
        self.mhstats[hashval]['made-proposals'] = 0
    self.accepted_proposals += 1

  def add_made_proposal(self, hashval):
    """Count a made proposal (and per-choice, when details are on)."""
    if self.mhstats_details:
      if hashval in self.mhstats:
        self.mhstats[hashval]['made-proposals'] += 1
      else:
        self.mhstats[hashval] = {}
        self.mhstats[hashval]['made-proposals'] = 1
        self.mhstats[hashval]['accepted-proposals'] = 0
    self.made_proposals += 1

  def reflip(self, hashval):
    """Propose a change to the choice identified by `hashval`.

    Application reflip when hashval is a registered choice; otherwise an
    internal state transition of the XRP registered under hashval.
    Returns (old_p, old_to_new_q, new_p, new_to_old_q).
    """
    if self.debug:
      print self

    if hashval in self.choices:
      self.application_reflip = True
      self.reflip_node = self.choices[hashval]
      if not self.reflip_node.random_xrp_apply:
        raise RException("Reflipping something which isn't a random xrp application")
      if self.reflip_node.val is None:
        raise RException("Reflipping something which previously had value None")
    else:
      self.application_reflip = False  # internal reflip
      (self.reflip_xrp, nodes) = self.xrps[hashval]
      self.nodes = list_nodes(nodes)

    # Fresh accumulators for log-probs of evaluation/unevaluation done
    # during this proposal (filled in by add_xrp/remove_xrp callbacks).
    self.eval_p = 0
    self.uneval_p = 0

    old_p = self.p
    self.old_to_new_q = - math.log(self.weight())
    if self.application_reflip:
      self.old_val = self.reflip_node.val
      self.new_val = self.reflip_node.reflip()
    else:
      # TODO: this is copied from traces.  is it correct?
      args_list = []
      self.old_vals = []
      for node in self.nodes:
        args_list.append(node.args)
        self.old_vals.append(node.val)
      self.old_to_new_q += math.log(self.reflip_xrp.state_weight())
      old_p += self.reflip_xrp.theta_prob()
      self.new_vals, q_forwards, q_back = self.reflip_xrp.theta_mh_prop(args_list, self.old_vals)
      self.old_to_new_q += q_forwards
      self.new_to_old_q += q_back
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        val = self.new_vals[i]
        node.set_val(val)
        node.propagate_up(False)

    new_p = self.p
    # NOTE(review): this assignment overwrites the q_back added to
    # new_to_old_q above; the Traces engine uses `-=` here instead.
    # Possibly a bug -- confirm against upstream.
    self.new_to_old_q = - math.log(self.weight())
    self.old_to_new_q += self.eval_p
    self.new_to_old_q += self.uneval_p
    if not self.application_reflip:
      new_p += self.reflip_xrp.theta_prob()
      self.new_to_old_q += math.log(self.reflip_xrp.state_weight())

    if self.debug:
      if self.application_reflip:
        print "\nCHANGING VAL OF ", self.reflip_node, "\n  FROM  : ", self.old_val, "\n  TO   : ", self.new_val, "\n"
        if (self.old_val.__eq__(self.new_val)).bool:
          print "SAME VAL"
      else:
        print "TRANSITIONING STATE OF ", self.reflip_xrp
      print "new db", self
      print "\nq(old -> new) : ", math.exp(self.old_to_new_q)
      print "q(new -> old) : ", math.exp(self.new_to_old_q)
      print "p(old) : ", math.exp(old_p)
      print "p(new) : ", math.exp(new_p)
      print 'transition prob : ', math.exp(new_p + self.new_to_old_q - old_p - self.old_to_new_q), "\n"
      print "\n-----------------------------------------\n"

    return old_p, self.old_to_new_q, new_p, self.new_to_old_q

  def restore(self):
    """Undo the last proposal (rejected move)."""
    if self.application_reflip:
      self.reflip_node.restore(self.old_val)
    else:
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        node.set_val(self.old_vals[i])
        node.propagate_up(True)
      self.reflip_xrp.theta_mh_restore()
    if self.debug:
      print 'restore'

  def keep(self):
    """Commit the last proposal (accepted move)."""
    if self.application_reflip:
      self.reflip_node.restore(self.new_val)  # NOTE: Is necessary for correctness, as we must forget old branch
    else:
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        node.restore(self.new_vals[i])
      self.reflip_xrp.theta_mh_keep()

  def add_for_transition(self, xrp, evalnode):
    """Register `evalnode` as an application of `xrp` for internal transitions."""
    hashval = xrp.__hash__()
    if hashval not in self.xrps:
      self.xrps[hashval] = (xrp, {})
    (xrp, evalnodes) = self.xrps[hashval]
    evalnodes[evalnode] = True
    # Re-insert under the XRP's (possibly changed) state weight.
    weight = xrp.state_weight()
    self.delete_from_db(xrp.__hash__())
    self.add_to_db(xrp.__hash__(), weight)

  # Add an XRP application node to the db
  def add_xrp(self, xrp, args, evalnode):
    """Register a random XRP application as a reflippable choice."""
    weight = xrp.weight(args)
    evalnode.setargs(args)
    try:
      # log(0) raises -- weight-0 choices contribute nothing to q.
      self.new_to_old_q += math.log(weight)
    except:
      pass  # This is only necessary if we're reflipping
    if self.weighted_db.__contains__(evalnode.hashval) or self.db.__contains__(evalnode.hashval) or (evalnode.hashval in self.choices):
      raise RException("DB already had this evalnode")
    self.choices[evalnode.hashval] = evalnode
    self.add_to_db(evalnode.hashval, weight)

  def add_to_db(self, hashval, weight):
    """Insert a choice: weight 0 is skipped, weight 1 goes to the uniform db."""
    if weight == 0:
      return
    elif weight == 1:
      self.db.__setitem__(hashval, True)
    else:
      self.weighted_db.__setitem__(hashval, True, weight)

  def remove_for_transition(self, xrp, evalnode):
    """Unregister an application node; drop the XRP when none remain."""
    hashval = xrp.__hash__()
    (xrp, evalnodes) = self.xrps[hashval]
    del evalnodes[evalnode]
    if len(evalnodes) == 0:
      del self.xrps[hashval]
      self.delete_from_db(xrp.__hash__())

  def remove_xrp(self, evalnode):
    """Remove a random XRP application from the choice databases."""
    xrp = evalnode.xrp
    try:
      self.old_to_new_q += math.log(xrp.weight(evalnode.args))
    except:
      pass  # This fails when restoring/keeping, for example
    if evalnode.hashval not in self.choices:
      raise RException("Choices did not already have this evalnode")
    else:
      del self.choices[evalnode.hashval]
    self.delete_from_db(evalnode.hashval)

  def delete_from_db(self, hashval):
    """Delete hashval from whichever database holds it (no-op if absent)."""
    if self.db.__contains__(hashval):
      self.db.__delitem__(hashval)
    elif self.weighted_db.__contains__(hashval):
      self.weighted_db.__delitem__(hashval)

  def randomKey(self):
    """Sample a choice hash, weighted across the two databases."""
    if rrandom.random.random() * self.weight() > self.db.weight():
      return self.weighted_db.randomKey()
    else:
      return self.db.randomKey()

  def infer(self):
    """Perform one MH step: propose via reflip(), then accept or reject."""
    try:
      hashval = self.randomKey()
    except:
      raise RException("Program has no randomness!")
    old_p, old_to_new_q, new_p, new_to_old_q = self.reflip(hashval)
    p = rrandom.random.random()
    # Accept iff log(p) <= new_p + q(new->old) - old_p - q(old->new).
    if new_p + new_to_old_q - old_p - old_to_new_q < math.log(p):
      self.restore()
    else:
      self.keep()
      self.add_accepted_proposal(hashval)
    self.add_made_proposal(hashval)

  def reset(self):
    """Re-initialize the engine, discarding all directives and state."""
    self.__init__()

  def __str__(self):
    string = "EvalNodeTree:"
    for evalnode in self.assumes.values():
      string += evalnode.str_helper()
    for evalnode in self.observes.values():
      string += evalnode.str_helper()
    for evalnode in self.predicts.values():
      string += evalnode.str_helper()
    return string
class Traces(Engine):
  """Metropolis-Hastings engine over a full trace of evaluation nodes.

  Like ReducedTraces, but tracks every XRP application/unevaluation
  during a proposal (eval_xrps / uneval_xrps), incorporates values into
  the XRPs eagerly in add_xrp, and uses EvalNode with two-argument
  propagate_up.

  NOTE(review): the original file had statements collapsed onto single
  lines; all indentation below is reconstructed and should be confirmed
  against the upstream source.
  """

  def __init__(self):
    """Set up directive tables, choice databases, and MH scratch state."""
    self.engine_type = "traces"
    self.assumes = {}    # id -> evalnode
    self.observes = {}   # id -> evalnode
    self.predicts = {}   # id -> evalnode
    self.directives = []
    self.db = RandomChoiceDict()
    self.weighted_db = WeightedRandomChoiceDict()
    self.choices = {}  # hash -> evalnode
    self.xrps = {}     # hash -> (xrp, set of application nodes)
    self.env = EnvironmentNode()
    # Log-probability accumulators for the MH acceptance ratio.
    self.p = 0
    self.old_to_new_q = 0
    self.new_to_old_q = 0
    self.eval_xrps = []    # (xrp, args, val)
    self.uneval_xrps = []  # (xrp, args, val)
    self.debug = False
    # Stuff for restoring -- placeholders overwritten by reflip().
    self.application_reflip = False
    self.reflip_node = EvalNode(self, self.env, VarExpression(""))
    self.nodes = []
    self.old_vals = []
    self.new_vals = []
    self.old_val = Value()
    self.new_val = Value()
    self.reflip_xrp = XRP()
    # Proposal statistics.
    self.made_proposals = 0
    self.accepted_proposals = 0
    self.mhstats_details = False
    self.mhstats = {}
    return

  def mhstats_on(self):
    """Enable per-choice proposal statistics collection."""
    self.mhstats_details = True

  def mhstats_off(self):
    """Disable per-choice proposal statistics collection."""
    self.mhstats_details = False

  def mhstats_aggregated(self):
    """Return totals of made/accepted proposals as a dict."""
    d = {}
    d["made-proposals"] = self.made_proposals
    d["accepted-proposals"] = self.accepted_proposals
    return d

  def mhstats_detailed(self):
    """Return the per-choice proposal statistics dict."""
    return self.mhstats

  def get_log_score(self, id):
    """Return the trace log-probability, or an observe's log-prob for id >= 0."""
    if id == -1:
      return self.p
    else:
      return self.observes[id].p

  def weight(self):
    """Total proposal weight over both choice databases."""
    return self.db.weight() + self.weighted_db.weight()

  def random_choices(self):
    """Number of random choices currently registered."""
    return self.db.__len__() + self.weighted_db.__len__()

  def assume(self, name, expr, id):
    """Bind `name` to the value of `expr` in the global environment.

    id == -1 means an internal assume that is not recorded as a directive.
    Unlike ReducedTraces.assume, the environment bindings are added
    around evaluation and the value is also set directly on the env.
    """
    evalnode = EvalNode(self, self.env, expr)
    if id != -1:
      self.assumes[id] = evalnode
      if id != len(self.directives):
        raise RException("Id %d does not agree with directives length of %d" % (id, len(self.directives)))
      self.directives.append("assume")
    val = evalnode.evaluate()
    self.env.add_assume(name, evalnode)
    evalnode.add_assume(name, self.env)
    self.env.set(name, val)
    return val

  def predict(self, expr, id):
    """Evaluate `expr` (unconstrained) and record it as a predict directive."""
    evalnode = EvalNode(self, self.env, expr)
    if id != len(self.directives):
      raise RException("Id %d does not agree with directives length of %d" % (id, len(self.directives)))
    self.directives.append("predict")
    self.predicts[id] = evalnode
    val = evalnode.evaluate(False)
    return val

  def observe(self, expr, obs_val, id):
    """Evaluate `expr` constrained to the observed value `obs_val`."""
    evalnode = EvalNode(self, self.env, expr)
    if id != len(self.directives):
      raise RException("Id %d does not agree with directives length of %d" % (id, len(self.directives)))
    self.directives.append("observe")
    self.observes[id] = evalnode
    evalnode.observed = True
    evalnode.observe_val = obs_val
    val = evalnode.evaluate()
    return val

  def forget(self, id):
    """Unevaluate the observe/predict with this id (assumes cannot be forgotten)."""
    if id in self.observes:
      d = self.observes
    elif id in self.predicts:
      d = self.predicts
    else:
      raise RException("Can only forget predicts and observes")
    evalnode = d[id]
    evalnode.unevaluate()
    # del d[id]
    return

  def report_value(self, id):
    """Return the current value of directive `id`; error if it was forgotten."""
    node = self.get_directive_node(id)
    if not node.active:
      raise RException("Error.  Perhaps this directive was forgotten?")
    val = node.val
    return val

  def get_directive_node(self, id):
    """Look up the evalnode for directive `id` by its recorded kind."""
    if self.directives[id] == "assume":
      node = self.assumes[id]
    elif self.directives[id] == "observe":
      node = self.observes[id]
    elif self.directives[id] == "predict":
      node = self.predicts[id]
    else:
      raise RException("Invalid directive")
    return node

  def add_accepted_proposal(self, hashval):
    """Count an accepted proposal (and per-choice, when details are on)."""
    if self.mhstats_details:
      if hashval in self.mhstats:
        self.mhstats[hashval]["accepted-proposals"] += 1
      else:
        self.mhstats[hashval] = {}
        self.mhstats[hashval]["accepted-proposals"] = 1
        self.mhstats[hashval]["made-proposals"] = 0
    self.accepted_proposals += 1

  def add_made_proposal(self, hashval):
    """Count a made proposal (and per-choice, when details are on)."""
    if self.mhstats_details:
      if hashval in self.mhstats:
        self.mhstats[hashval]["made-proposals"] += 1
      else:
        self.mhstats[hashval] = {}
        self.mhstats[hashval]["made-proposals"] = 1
        self.mhstats[hashval]["accepted-proposals"] = 0
    self.made_proposals += 1

  def reflip(self, hashval):
    """Propose a change to the choice identified by `hashval`.

    Application reflip when hashval is a registered choice; otherwise an
    internal state transition of the XRP registered under hashval.
    Returns (old_p, old_to_new_q, new_p, new_to_old_q).
    """
    if self.debug:
      print self

    if hashval in self.choices:
      self.application_reflip = True
      self.reflip_node = self.choices[hashval]
      if not self.reflip_node.random_xrp_apply:
        raise RException("Reflipping something which isn't a random xrp application")
      if self.reflip_node.val is None:
        raise RException("Reflipping something which previously had value None")
    else:
      self.application_reflip = False  # internal reflip
      (self.reflip_xrp, nodes) = self.xrps[hashval]
      self.nodes = list_nodes(nodes)

    # Reset the proposal log-densities; add_xrp/remove_xrp accumulate
    # into them during re-evaluation.
    self.old_to_new_q = 0
    self.new_to_old_q = 0
    old_p = self.p
    # q(old->new) starts with the probability of picking this choice.
    self.old_to_new_q = -math.log(self.weight())
    if self.application_reflip:
      self.old_val = self.reflip_node.val
      self.new_val = self.reflip_node.reflip()
    else:
      args_list = []
      self.old_vals = []
      for node in self.nodes:
        if not node.xrp_apply:
          raise RException("non-XRP application node being used in transition")
        args_list.append(node.args)
        self.old_vals.append(node.val)
      self.old_to_new_q += math.log(self.reflip_xrp.state_weight())
      old_p += self.reflip_xrp.theta_prob()
      self.new_vals, q_forwards, q_back = self.reflip_xrp.theta_mh_prop(args_list, self.old_vals)
      self.old_to_new_q += q_forwards
      self.new_to_old_q += q_back
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        val = self.new_vals[i]
        node.set_val(val)
        node.propagate_up(False, True)

    new_p = self.p
    # Note: `-=` here preserves q_back, unlike ReducedTraces which
    # reassigns with `=` at the corresponding point.
    self.new_to_old_q -= math.log(self.weight())
    if not self.application_reflip:
      new_p += self.reflip_xrp.theta_prob()
      self.new_to_old_q += math.log(self.reflip_xrp.state_weight())

    if self.debug:
      if self.application_reflip:
        print "\nCHANGING VAL OF ", self.reflip_node, "\n  FROM  : ", self.old_val, "\n  TO   : ", self.new_val, "\n"
        if (self.old_val.__eq__(self.new_val)).bool:
          print "SAME VAL"
      else:
        print "TRANSITIONING STATE OF ", self.reflip_xrp
      print "new db", self
      print "\nq(old -> new) : ", math.exp(self.old_to_new_q)
      print "q(new -> old) : ", math.exp(self.new_to_old_q)
      print "p(old) : ", math.exp(old_p)
      print "p(new) : ", math.exp(new_p)
      print "transition prob : ", math.exp(new_p + self.new_to_old_q - old_p - self.old_to_new_q), "\n"

    return old_p, self.old_to_new_q, new_p, self.new_to_old_q

  def restore(self):
    """Undo the last proposal (rejected move)."""
    if self.application_reflip:
      self.reflip_node.restore(self.old_val)
    else:
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        node.set_val(self.old_vals[i])
        node.propagate_up(True, True)
      self.reflip_xrp.theta_mh_restore()
    if self.debug:
      print "restore"
      print "\n-----------------------------------------\n"

  def keep(self):
    """Commit the last proposal (accepted move)."""
    if self.application_reflip:
      self.reflip_node.restore(self.new_val)  # NOTE: Is necessary for correctness, as we must forget old branch
    else:
      for i in range(len(self.nodes)):
        node = self.nodes[i]
        node.restore(self.new_vals[i])
      self.reflip_xrp.theta_mh_keep()
    if self.debug:
      print "keep"
      print "\n-----------------------------------------\n"

  def add_for_transition(self, xrp, evalnode):
    """Register `evalnode` as an application of `xrp` for internal transitions."""
    hashval = xrp.__hash__()
    if hashval not in self.xrps:
      self.xrps[hashval] = (xrp, {})
    (xrp, evalnodes) = self.xrps[hashval]
    evalnodes[evalnode] = True
    # Re-insert under the XRP's (possibly changed) state weight.
    weight = xrp.state_weight()
    self.delete_from_db(hashval)
    self.add_to_db(hashval, weight)

  # Add an XRP application node to the db
  def add_xrp(self, xrp, args, val, evalnode, forcing=False):
    """Incorporate an XRP application into the trace and, when the XRP is
    not resampled transparently, register it as a reflippable choice.

    forcing=True marks applications whose value was imposed (e.g. by an
    observe) rather than sampled, so they do not contribute to q(old->new).
    """
    self.eval_xrps.append((xrp, args, val))
    evalnode.forcing = forcing
    if not xrp.resample:
      evalnode.random_xrp_apply = True
    prob = xrp.prob(val, args)
    evalnode.setargs(xrp, args, prob)
    xrp.incorporate(val, args)
    xrp.make_link(evalnode, args)
    if not forcing:
      self.old_to_new_q += prob
    self.p += prob
    self.add_for_transition(xrp, evalnode)
    if xrp.resample:
      return
    weight = xrp.weight(args)
    try:
      # log(0) raises -- weight-0 choices contribute nothing to q.
      self.new_to_old_q += math.log(weight)
    except:
      pass  # This is only necessary if we're reflipping
    if (self.weighted_db.__contains__(evalnode.hashval) or self.db.__contains__(evalnode.hashval) or (evalnode.hashval in self.choices)):
      raise RException("DB already had this evalnode")
    self.choices[evalnode.hashval] = evalnode
    self.add_to_db(evalnode.hashval, weight)

  def add_to_db(self, hashval, weight):
    """Insert a choice: weight 0 is skipped, weight 1 goes to the uniform db."""
    if weight == 0:
      return
    elif weight == 1:
      self.db.__setitem__(hashval, True)
    else:
      self.weighted_db.__setitem__(hashval, True, weight)

  def remove_for_transition(self, xrp, evalnode):
    """Unregister an application node; drop the XRP when none remain."""
    hashval = xrp.__hash__()
    (xrp, evalnodes) = self.xrps[hashval]
    del evalnodes[evalnode]
    if len(evalnodes) == 0:
      del self.xrps[hashval]
      self.delete_from_db(hashval)

  def remove_xrp(self, xrp, args, val, evalnode):
    """Remove an XRP application from the trace and the choice databases."""
    self.uneval_xrps.append((xrp, args, val))
    xrp.break_link(evalnode, args)
    xrp.remove(val, args)
    prob = xrp.prob(val, args)
    if not evalnode.forcing:
      self.new_to_old_q += prob
    self.p -= prob
    self.remove_for_transition(xrp, evalnode)
    if xrp.resample:
      return
    # NOTE(review): this rebinding shadows the `xrp` parameter; presumably
    # evalnode.xrp is the same object -- confirm before simplifying.
    xrp = evalnode.xrp
    try:
      # TODO: dont do this here.. dont do cancellign
      self.old_to_new_q += math.log(xrp.weight(evalnode.args))
    except:
      pass  # This fails when restoring/keeping, for example
    if evalnode.hashval not in self.choices:
      raise RException("Choices did not already have this evalnode")
    else:
      del self.choices[evalnode.hashval]
    self.delete_from_db(evalnode.hashval)

  def delete_from_db(self, hashval):
    """Delete hashval from whichever database holds it (no-op if absent)."""
    if self.db.__contains__(hashval):
      self.db.__delitem__(hashval)
    elif self.weighted_db.__contains__(hashval):
      self.weighted_db.__delitem__(hashval)

  def randomKey(self):
    """Sample a choice hash, weighted across the two databases."""
    if rrandom.random.random() * self.weight() > self.db.weight():
      return self.weighted_db.randomKey()
    else:
      return self.db.randomKey()

  def infer(self):
    """Perform one MH step: propose via reflip(), then accept or reject.

    NOTE(review): the acceptance test reads self.new_to_old_q /
    self.old_to_new_q rather than the returned locals (ReducedTraces uses
    the locals); reflip() returns those same attributes, so the values
    agree -- but confirm before refactoring.
    """
    try:
      hashval = self.randomKey()
    except:
      raise RException("Program has no randomness!")
    old_p, old_to_new_q, new_p, new_to_old_q = self.reflip(hashval)
    p = rrandom.random.random()
    if new_p + self.new_to_old_q - old_p - self.old_to_new_q < math.log(p):
      self.restore()
    else:
      self.keep()
      self.add_accepted_proposal(hashval)
    self.add_made_proposal(hashval)

  def reset(self):
    """Re-initialize the engine, discarding all directives and state."""
    self.__init__()

  def __str__(self):
    string = "EvalNodeTree:"
    for evalnode in self.assumes.values():
      string += evalnode.str_helper()
    for evalnode in self.observes.values():
      string += evalnode.str_helper()
    for evalnode in self.predicts.values():
      string += evalnode.str_helper()
    return string