class ReducedTraces(Engine):
    """Metropolis-Hastings inference engine over reduced evaluation traces.

    Directives (assume / observe / predict) are evaluated into
    ``ReducedEvalNode`` trees.  Proposal targets are identified by hash value
    and kept in two random-choice databases: ``db`` for targets whose weight
    is exactly 1 and ``weighted_db`` for targets with any other nonzero
    weight.  A target is either a random XRP application node (looked up in
    ``choices``) or an XRP whose internal state can be transitioned (looked
    up in ``xrps``).  Scores and proposal densities are log values.
    """

    def __init__(self):
        self.engine_type = 'reduced traces'

        self.assumes = {}  # id -> evalnode
        self.observes = {}  # id -> evalnode
        self.predicts = {}  # id -> evalnode
        self.directives = []  # id -> 'assume' | 'observe' | 'predict'

        self.db = RandomChoiceDict()  # target hashvals with weight exactly 1
        self.weighted_db = WeightedRandomChoiceDict()  # target hashvals with other nonzero weights
        self.choices = {}  # hash -> evalnode
        self.xrps = {}  # hash -> (xrp, dict used as a set of application nodes)

        self.env = EnvironmentNode()

        # Log-probability bookkeeping.  eval_p / uneval_p are reset at the
        # start of every reflip (presumably accumulated by the eval nodes).
        self.p = 0
        self.uneval_p = 0
        self.eval_p = 0
        self.new_to_old_q = 0
        self.old_to_new_q = 0

        self.debug = False

        # necessary because of the new XRP interface requiring some state kept while doing inference
        self.application_reflip = False
        self.reflip_node = ReducedEvalNode(self, self.env, VarExpression(''))
        self.nodes = []
        self.old_vals = [Value()]
        self.new_vals = [Value()]
        self.old_val = Value()
        self.new_val = Value()
        self.reflip_xrp = XRP()

        self.mhstats_details = False
        self.mhstats = {}
        self.made_proposals = 0
        self.accepted_proposals = 0

        self.hashval = rrandom.random.randbelow()
        return

    def mhstats_on(self):
        """Start collecting per-target proposal statistics."""
        self.mhstats_details = True

    def mhstats_off(self):
        """Stop collecting per-target proposal statistics."""
        self.mhstats_details = False

    def mhstats_aggregated(self):
        """Return the total numbers of made and accepted proposals."""
        return {
            'made-proposals': self.made_proposals,
            'accepted-proposals': self.accepted_proposals,
        }

    def mhstats_detailed(self):
        """Return the per-hashval proposal counters."""
        return self.mhstats

    def get_log_score(self, id):
        """Return the total log score (``id == -1``) or that of observe ``id``."""
        if id == -1:
            return self.p
        return self.observes[id].p

    def weight(self):
        """Total weight of all proposal targets in both databases."""
        return self.db.weight() + self.weighted_db.weight()

    def random_choices(self):
        """Number of proposal targets in both databases."""
        return len(self.db) + len(self.weighted_db)

    def assume(self, name, expr, id=-1):
        """Evaluate ``expr`` and bind it to ``name`` in the global environment.

        When ``id`` is -1 the assume is not recorded as a directive.
        """
        evalnode = ReducedEvalNode(self, self.env, expr)
        self.env.add_assume(name, evalnode)
        evalnode.add_assume(name, self.env)
        if id != -1:
            self.assumes[id] = evalnode
            assert id == len(self.directives)
            self.directives.append('assume')
        return evalnode.evaluate()

    def predict(self, expr, id):
        """Record and evaluate a predict directive; return its value."""
        evalnode = ReducedEvalNode(self, self.env, expr)
        assert id == len(self.directives)
        self.directives.append('predict')
        self.predicts[id] = evalnode
        evalnode.predict = True
        return evalnode.evaluate()

    def observe(self, expr, obs_val, id):
        """Record and evaluate an observe directive constrained to ``obs_val``."""
        evalnode = ReducedEvalNode(self, self.env, expr)
        assert id == len(self.directives)
        self.directives.append('observe')
        self.observes[id] = evalnode
        evalnode.observed = True
        evalnode.observe_val = obs_val
        return evalnode.evaluate()

    def forget(self, id):
        """Unevaluate observe or predict ``id``.

        Raises:
            RException: if ``id`` is neither an observe nor a predict.
        """
        if id in self.observes:
            d = self.observes
        elif id in self.predicts:
            d = self.predicts
        else:
            raise RException("Can only forget predicts and observes")
        # The node is intentionally left in its table; report_value checks
        # node.active to detect forgotten directives.
        d[id].unevaluate()
        return

    def report_value(self, id):
        """Return the current value of directive ``id``.

        Raises:
            RException: if the directive's node is no longer active.
        """
        node = self.get_directive_node(id)
        if not node.active:
            raise RException("Error. Perhaps this directive was forgotten?")
        return node.val

    def get_directive_node(self, id):
        """Return the eval node for directive ``id`` based on its recorded kind."""
        kind = self.directives[id]
        if kind == 'assume':
            return self.assumes[id]
        if kind == 'observe':
            return self.observes[id]
        assert kind == 'predict'
        return self.predicts[id]

    def _bump_mhstat(self, hashval, key):
        """Increment per-target counter ``key``, creating the entry on first use."""
        if hashval not in self.mhstats:
            self.mhstats[hashval] = {'accepted-proposals': 0, 'made-proposals': 0}
        self.mhstats[hashval][key] += 1

    def add_accepted_proposal(self, hashval):
        """Count an accepted proposal (per target only while mhstats are on)."""
        if self.mhstats_details:
            self._bump_mhstat(hashval, 'accepted-proposals')
        self.accepted_proposals += 1

    def add_made_proposal(self, hashval):
        """Count a made proposal (per target only while mhstats are on)."""
        if self.mhstats_details:
            self._bump_mhstat(hashval, 'made-proposals')
        self.made_proposals += 1

    def reflip(self, hashval):
        """Apply a proposal for target ``hashval``.

        The proposal is applied to the trace immediately; the caller must
        follow up with ``keep()`` or ``restore()``.

        Returns:
            (old_p, old_to_new_q, new_p, new_to_old_q), all log values.

        Raises:
            RException: if an application target is not a random XRP
                application or previously had value None.
        """
        if self.debug:
            print(self)

        if hashval in self.choices:
            self.application_reflip = True
            self.reflip_node = self.choices[hashval]
            if not self.reflip_node.random_xrp_apply:
                raise RException("Reflipping something which isn't a random xrp application")
            if self.reflip_node.val is None:
                raise RException("Reflipping something which previously had value None")
        else:
            self.application_reflip = False  # internal reflip
            (self.reflip_xrp, nodes) = self.xrps[hashval]
            self.nodes = list_nodes(nodes)

        self.eval_p = 0
        self.uneval_p = 0

        old_p = self.p
        # -log(total weight of all proposal targets) before the move.
        self.old_to_new_q = -math.log(self.weight())

        # Reverse term from theta_mh_prop; stays 0 for application reflips.
        q_back = 0
        if self.application_reflip:
            self.old_val = self.reflip_node.val
            self.new_val = self.reflip_node.reflip()
        else:
            # TODO: this is copied from traces. is it correct?
            args_list = []
            self.old_vals = []
            for node in self.nodes:
                args_list.append(node.args)
                self.old_vals.append(node.val)

            self.old_to_new_q += math.log(self.reflip_xrp.state_weight())
            old_p += self.reflip_xrp.theta_prob()

            self.new_vals, q_forwards, q_back = self.reflip_xrp.theta_mh_prop(args_list, self.old_vals)
            self.old_to_new_q += q_forwards

            for i, node in enumerate(self.nodes):
                node.set_val(self.new_vals[i])
                node.propagate_up(False)

        new_p = self.p
        # q_back must be added AFTER this assignment: it used to be added
        # beforehand and was then silently overwritten here, while its
        # counterpart q_forwards was kept.
        # NOTE(review): this assignment also discards whatever add_xrp added
        # to new_to_old_q during the reflip, whereas remove_xrp's additions to
        # old_to_new_q are kept -- confirm that asymmetry is intended.
        self.new_to_old_q = -math.log(self.weight()) + q_back
        self.old_to_new_q += self.eval_p
        self.new_to_old_q += self.uneval_p

        if not self.application_reflip:
            new_p += self.reflip_xrp.theta_prob()
            self.new_to_old_q += math.log(self.reflip_xrp.state_weight())

        if self.debug:
            if self.application_reflip:
                print("\nCHANGING VAL OF %s\n FROM : %s\n TO : %s\n"
                      % (self.reflip_node, self.old_val, self.new_val))
                if self.old_val.__eq__(self.new_val).bool:
                    print("SAME VAL")
            else:
                print("TRANSITIONING STATE OF %s" % (self.reflip_xrp,))
            print("new db %s" % (self,))
            print("\nq(old -> new) : %s" % math.exp(self.old_to_new_q))
            print("q(new -> old) : %s" % math.exp(self.new_to_old_q))
            print("p(old) : %s" % math.exp(old_p))
            print("p(new) : %s" % math.exp(new_p))
            print("transition prob : %s\n"
                  % math.exp(new_p + self.new_to_old_q - old_p - self.old_to_new_q))
            print("\n-----------------------------------------\n")

        return old_p, self.old_to_new_q, new_p, self.new_to_old_q

    def restore(self):
        """Undo the proposal applied by the last ``reflip``."""
        if self.application_reflip:
            self.reflip_node.restore(self.old_val)
        else:
            for i, node in enumerate(self.nodes):
                node.set_val(self.old_vals[i])
                node.propagate_up(True)
            self.reflip_xrp.theta_mh_restore()
        if self.debug:
            print('restore')

    def keep(self):
        """Commit the proposal applied by the last ``reflip``."""
        if self.application_reflip:
            self.reflip_node.restore(self.new_val)  # NOTE: Is necessary for correctness, as we must forget old branch
        else:
            for i, node in enumerate(self.nodes):
                node.restore(self.new_vals[i])
            self.reflip_xrp.theta_mh_keep()

    def add_for_transition(self, xrp, evalnode):
        """Register ``evalnode`` as an application of ``xrp`` and refresh the XRP's db weight."""
        hashval = xrp.__hash__()
        if hashval not in self.xrps:
            self.xrps[hashval] = (xrp, {})
        (xrp, evalnodes) = self.xrps[hashval]
        evalnodes[evalnode] = True

        weight = xrp.state_weight()
        self.delete_from_db(hashval)
        self.add_to_db(hashval, weight)

    # Add an XRP application node to the db
    def add_xrp(self, xrp, args, evalnode):
        """Register a random XRP application node as a proposal target.

        Raises:
            RException: if the node's hashval is already registered.
        """
        weight = xrp.weight(args)
        evalnode.setargs(args)
        try:
            self.new_to_old_q += math.log(weight)
        except Exception:
            pass  # This is only necessary if we're reflipping (e.g. log(0) fails)

        hashval = evalnode.hashval
        if hashval in self.weighted_db or hashval in self.db or hashval in self.choices:
            raise RException("DB already had this evalnode")
        self.choices[hashval] = evalnode
        self.add_to_db(hashval, weight)

    def add_to_db(self, hashval, weight):
        """Store ``hashval`` in ``db`` (weight 1) or ``weighted_db``; weight 0 is skipped."""
        if weight == 0:
            return
        if weight == 1:
            self.db[hashval] = True
        else:
            self.weighted_db.__setitem__(hashval, True, weight)

    def remove_for_transition(self, xrp, evalnode):
        """Unregister an application of ``xrp``; drop the XRP target when none remain."""
        hashval = xrp.__hash__()
        (xrp, evalnodes) = self.xrps[hashval]
        del evalnodes[evalnode]
        if len(evalnodes) == 0:
            del self.xrps[hashval]
            self.delete_from_db(hashval)

    def remove_xrp(self, evalnode):
        """Unregister a random XRP application node.

        Raises:
            RException: if the node was not registered.
        """
        xrp = evalnode.xrp
        try:
            self.old_to_new_q += math.log(xrp.weight(evalnode.args))
        except Exception:
            pass  # This fails when restoring/keeping, for example

        if evalnode.hashval not in self.choices:
            raise RException("Choices did not already have this evalnode")
        del self.choices[evalnode.hashval]
        self.delete_from_db(evalnode.hashval)

    def delete_from_db(self, hashval):
        """Remove ``hashval`` from whichever database holds it (no-op if neither)."""
        if hashval in self.db:
            del self.db[hashval]
        elif hashval in self.weighted_db:
            del self.weighted_db[hashval]

    def randomKey(self):
        """Pick a target: choose a database in proportion to its total weight, then delegate."""
        if rrandom.random.random() * self.weight() > self.db.weight():
            return self.weighted_db.randomKey()
        return self.db.randomKey()

    def infer(self):
        """Perform one Metropolis-Hastings step.

        Raises:
            RException: if there is no proposal target to choose from.
        """
        try:
            hashval = self.randomKey()
        except Exception:
            raise RException("Program has no randomness!")

        old_p, old_to_new_q, new_p, new_to_old_q = self.reflip(hashval)
        u = rrandom.random.random()
        # math.log(0) raises; log(0) is -inf, which never rejects.
        log_u = math.log(u) if u > 0 else float('-inf')
        if new_p + new_to_old_q - old_p - old_to_new_q < log_u:
            self.restore()
        else:
            self.keep()
            self.add_accepted_proposal(hashval)
        self.add_made_proposal(hashval)

    def reset(self):
        """Reinitialise the engine to an empty state."""
        self.__init__()

    def __str__(self):
        nodes = (list(self.assumes.values())
                 + list(self.observes.values())
                 + list(self.predicts.values()))
        return "EvalNodeTree:" + "".join(node.str_helper() for node in nodes)
class ReducedTraces(Engine):
    """Metropolis-Hastings inference engine over reduced evaluation traces.

    Directives (assume / observe / predict) are evaluated into
    ``ReducedEvalNode`` trees.  Proposal targets are identified by hash value
    and kept in two random-choice databases: ``db`` for targets whose weight
    is exactly 1 and ``weighted_db`` for targets with any other nonzero
    weight.  A target is either a random XRP application node (looked up in
    ``choices``) or an XRP whose internal state can be transitioned (looked
    up in ``xrps``).  Scores and proposal densities are log values.
    """

    def __init__(self):
        self.engine_type = 'reduced traces'

        self.assumes = {}  # id -> evalnode
        self.observes = {}  # id -> evalnode
        self.predicts = {}  # id -> evalnode
        self.directives = []  # id -> 'assume' | 'observe' | 'predict'

        self.db = RandomChoiceDict()  # target hashvals with weight exactly 1
        self.weighted_db = WeightedRandomChoiceDict()  # target hashvals with other nonzero weights
        self.choices = {}  # hash -> evalnode
        self.xrps = {}  # hash -> (xrp, dict used as a set of application nodes)

        self.env = EnvironmentNode()

        # Log-probability bookkeeping.  eval_p / uneval_p are reset at the
        # start of every reflip (presumably accumulated by the eval nodes).
        self.p = 0
        self.uneval_p = 0
        self.eval_p = 0
        self.new_to_old_q = 0
        self.old_to_new_q = 0

        self.debug = False

        # necessary because of the new XRP interface requiring some state kept while doing inference
        self.application_reflip = False
        self.reflip_node = ReducedEvalNode(self, self.env, VarExpression(''))
        self.nodes = []
        self.old_vals = [Value()]
        self.new_vals = [Value()]
        self.old_val = Value()
        self.new_val = Value()
        self.reflip_xrp = XRP()

        self.mhstats_details = False
        self.mhstats = {}
        self.made_proposals = 0
        self.accepted_proposals = 0

        self.hashval = rrandom.random.randbelow()
        return

    def mhstats_on(self):
        """Start collecting per-target proposal statistics."""
        self.mhstats_details = True

    def mhstats_off(self):
        """Stop collecting per-target proposal statistics."""
        self.mhstats_details = False

    def mhstats_aggregated(self):
        """Return the total numbers of made and accepted proposals."""
        return {
            'made-proposals': self.made_proposals,
            'accepted-proposals': self.accepted_proposals,
        }

    def mhstats_detailed(self):
        """Return the per-hashval proposal counters."""
        return self.mhstats

    def get_log_score(self, id):
        """Return the total log score (``id == -1``) or that of observe ``id``."""
        if id == -1:
            return self.p
        return self.observes[id].p

    def weight(self):
        """Total weight of all proposal targets in both databases."""
        return self.db.weight() + self.weighted_db.weight()

    def random_choices(self):
        """Number of proposal targets in both databases."""
        return len(self.db) + len(self.weighted_db)

    def assume(self, name, expr, id=-1):
        """Evaluate ``expr`` and bind it to ``name`` in the global environment.

        When ``id`` is -1 the assume is not recorded as a directive.
        """
        evalnode = ReducedEvalNode(self, self.env, expr)
        self.env.add_assume(name, evalnode)
        evalnode.add_assume(name, self.env)
        if id != -1:
            self.assumes[id] = evalnode
            assert id == len(self.directives)
            self.directives.append('assume')
        return evalnode.evaluate()

    def predict(self, expr, id):
        """Record and evaluate a predict directive; return its value."""
        evalnode = ReducedEvalNode(self, self.env, expr)
        assert id == len(self.directives)
        self.directives.append('predict')
        self.predicts[id] = evalnode
        evalnode.predict = True
        return evalnode.evaluate()

    def observe(self, expr, obs_val, id):
        """Record and evaluate an observe directive constrained to ``obs_val``."""
        evalnode = ReducedEvalNode(self, self.env, expr)
        assert id == len(self.directives)
        self.directives.append('observe')
        self.observes[id] = evalnode
        evalnode.observed = True
        evalnode.observe_val = obs_val
        return evalnode.evaluate()

    def forget(self, id):
        """Unevaluate observe or predict ``id``.

        Raises:
            RException: if ``id`` is neither an observe nor a predict.
        """
        if id in self.observes:
            d = self.observes
        elif id in self.predicts:
            d = self.predicts
        else:
            raise RException("Can only forget predicts and observes")
        # The node is intentionally left in its table; report_value checks
        # node.active to detect forgotten directives.
        d[id].unevaluate()
        return

    def report_value(self, id):
        """Return the current value of directive ``id``.

        Raises:
            RException: if the directive's node is no longer active.
        """
        node = self.get_directive_node(id)
        if not node.active:
            raise RException("Error. Perhaps this directive was forgotten?")
        return node.val

    def get_directive_node(self, id):
        """Return the eval node for directive ``id`` based on its recorded kind."""
        kind = self.directives[id]
        if kind == 'assume':
            return self.assumes[id]
        if kind == 'observe':
            return self.observes[id]
        assert kind == 'predict'
        return self.predicts[id]

    def _bump_mhstat(self, hashval, key):
        """Increment per-target counter ``key``, creating the entry on first use."""
        if hashval not in self.mhstats:
            self.mhstats[hashval] = {'accepted-proposals': 0, 'made-proposals': 0}
        self.mhstats[hashval][key] += 1

    def add_accepted_proposal(self, hashval):
        """Count an accepted proposal (per target only while mhstats are on)."""
        if self.mhstats_details:
            self._bump_mhstat(hashval, 'accepted-proposals')
        self.accepted_proposals += 1

    def add_made_proposal(self, hashval):
        """Count a made proposal (per target only while mhstats are on)."""
        if self.mhstats_details:
            self._bump_mhstat(hashval, 'made-proposals')
        self.made_proposals += 1

    def reflip(self, hashval):
        """Apply a proposal for target ``hashval``.

        The proposal is applied to the trace immediately; the caller must
        follow up with ``keep()`` or ``restore()``.

        Returns:
            (old_p, old_to_new_q, new_p, new_to_old_q), all log values.

        Raises:
            RException: if an application target is not a random XRP
                application or previously had value None.
        """
        if self.debug:
            print(self)

        if hashval in self.choices:
            self.application_reflip = True
            self.reflip_node = self.choices[hashval]
            if not self.reflip_node.random_xrp_apply:
                raise RException("Reflipping something which isn't a random xrp application")
            if self.reflip_node.val is None:
                raise RException("Reflipping something which previously had value None")
        else:
            self.application_reflip = False  # internal reflip
            (self.reflip_xrp, nodes) = self.xrps[hashval]
            self.nodes = list_nodes(nodes)

        self.eval_p = 0
        self.uneval_p = 0

        old_p = self.p
        # -log(total weight of all proposal targets) before the move.
        self.old_to_new_q = -math.log(self.weight())

        # Reverse term from theta_mh_prop; stays 0 for application reflips.
        q_back = 0
        if self.application_reflip:
            self.old_val = self.reflip_node.val
            self.new_val = self.reflip_node.reflip()
        else:
            # TODO: this is copied from traces. is it correct?
            args_list = []
            self.old_vals = []
            for node in self.nodes:
                args_list.append(node.args)
                self.old_vals.append(node.val)

            self.old_to_new_q += math.log(self.reflip_xrp.state_weight())
            old_p += self.reflip_xrp.theta_prob()

            self.new_vals, q_forwards, q_back = self.reflip_xrp.theta_mh_prop(args_list, self.old_vals)
            self.old_to_new_q += q_forwards

            for i, node in enumerate(self.nodes):
                node.set_val(self.new_vals[i])
                node.propagate_up(False)

        new_p = self.p
        # q_back must be added AFTER this assignment: it used to be added
        # beforehand and was then silently overwritten here, while its
        # counterpart q_forwards was kept.
        # NOTE(review): this assignment also discards whatever add_xrp added
        # to new_to_old_q during the reflip, whereas remove_xrp's additions to
        # old_to_new_q are kept -- confirm that asymmetry is intended.
        self.new_to_old_q = -math.log(self.weight()) + q_back
        self.old_to_new_q += self.eval_p
        self.new_to_old_q += self.uneval_p

        if not self.application_reflip:
            new_p += self.reflip_xrp.theta_prob()
            self.new_to_old_q += math.log(self.reflip_xrp.state_weight())

        if self.debug:
            if self.application_reflip:
                print("\nCHANGING VAL OF %s\n FROM : %s\n TO : %s\n"
                      % (self.reflip_node, self.old_val, self.new_val))
                if self.old_val.__eq__(self.new_val).bool:
                    print("SAME VAL")
            else:
                print("TRANSITIONING STATE OF %s" % (self.reflip_xrp,))
            print("new db %s" % (self,))
            print("\nq(old -> new) : %s" % math.exp(self.old_to_new_q))
            print("q(new -> old) : %s" % math.exp(self.new_to_old_q))
            print("p(old) : %s" % math.exp(old_p))
            print("p(new) : %s" % math.exp(new_p))
            print("transition prob : %s\n"
                  % math.exp(new_p + self.new_to_old_q - old_p - self.old_to_new_q))
            print("\n-----------------------------------------\n")

        return old_p, self.old_to_new_q, new_p, self.new_to_old_q

    def restore(self):
        """Undo the proposal applied by the last ``reflip``."""
        if self.application_reflip:
            self.reflip_node.restore(self.old_val)
        else:
            for i, node in enumerate(self.nodes):
                node.set_val(self.old_vals[i])
                node.propagate_up(True)
            self.reflip_xrp.theta_mh_restore()
        if self.debug:
            print('restore')

    def keep(self):
        """Commit the proposal applied by the last ``reflip``."""
        if self.application_reflip:
            self.reflip_node.restore(self.new_val)  # NOTE: Is necessary for correctness, as we must forget old branch
        else:
            for i, node in enumerate(self.nodes):
                node.restore(self.new_vals[i])
            self.reflip_xrp.theta_mh_keep()

    def add_for_transition(self, xrp, evalnode):
        """Register ``evalnode`` as an application of ``xrp`` and refresh the XRP's db weight."""
        hashval = xrp.__hash__()
        if hashval not in self.xrps:
            self.xrps[hashval] = (xrp, {})
        (xrp, evalnodes) = self.xrps[hashval]
        evalnodes[evalnode] = True

        weight = xrp.state_weight()
        self.delete_from_db(hashval)
        self.add_to_db(hashval, weight)

    # Add an XRP application node to the db
    def add_xrp(self, xrp, args, evalnode):
        """Register a random XRP application node as a proposal target.

        Raises:
            RException: if the node's hashval is already registered.
        """
        weight = xrp.weight(args)
        evalnode.setargs(args)
        try:
            self.new_to_old_q += math.log(weight)
        except Exception:
            pass  # This is only necessary if we're reflipping (e.g. log(0) fails)

        hashval = evalnode.hashval
        if hashval in self.weighted_db or hashval in self.db or hashval in self.choices:
            raise RException("DB already had this evalnode")
        self.choices[hashval] = evalnode
        self.add_to_db(hashval, weight)

    def add_to_db(self, hashval, weight):
        """Store ``hashval`` in ``db`` (weight 1) or ``weighted_db``; weight 0 is skipped."""
        if weight == 0:
            return
        if weight == 1:
            self.db[hashval] = True
        else:
            self.weighted_db.__setitem__(hashval, True, weight)

    def remove_for_transition(self, xrp, evalnode):
        """Unregister an application of ``xrp``; drop the XRP target when none remain."""
        hashval = xrp.__hash__()
        (xrp, evalnodes) = self.xrps[hashval]
        del evalnodes[evalnode]
        if len(evalnodes) == 0:
            del self.xrps[hashval]
            self.delete_from_db(hashval)

    def remove_xrp(self, evalnode):
        """Unregister a random XRP application node.

        Raises:
            RException: if the node was not registered.
        """
        xrp = evalnode.xrp
        try:
            self.old_to_new_q += math.log(xrp.weight(evalnode.args))
        except Exception:
            pass  # This fails when restoring/keeping, for example

        if evalnode.hashval not in self.choices:
            raise RException("Choices did not already have this evalnode")
        del self.choices[evalnode.hashval]
        self.delete_from_db(evalnode.hashval)

    def delete_from_db(self, hashval):
        """Remove ``hashval`` from whichever database holds it (no-op if neither)."""
        if hashval in self.db:
            del self.db[hashval]
        elif hashval in self.weighted_db:
            del self.weighted_db[hashval]

    def randomKey(self):
        """Pick a target: choose a database in proportion to its total weight, then delegate."""
        if rrandom.random.random() * self.weight() > self.db.weight():
            return self.weighted_db.randomKey()
        return self.db.randomKey()

    def infer(self):
        """Perform one Metropolis-Hastings step.

        Raises:
            RException: if there is no proposal target to choose from.
        """
        try:
            hashval = self.randomKey()
        except Exception:
            raise RException("Program has no randomness!")

        old_p, old_to_new_q, new_p, new_to_old_q = self.reflip(hashval)
        u = rrandom.random.random()
        # math.log(0) raises; log(0) is -inf, which never rejects.
        log_u = math.log(u) if u > 0 else float('-inf')
        if new_p + new_to_old_q - old_p - old_to_new_q < log_u:
            self.restore()
        else:
            self.keep()
            self.add_accepted_proposal(hashval)
        self.add_made_proposal(hashval)

    def reset(self):
        """Reinitialise the engine to an empty state."""
        self.__init__()

    def __str__(self):
        nodes = (list(self.assumes.values())
                 + list(self.observes.values())
                 + list(self.predicts.values()))
        return "EvalNodeTree:" + "".join(node.str_helper() for node in nodes)
class RandomDB(Engine):
    """Inference engine that stores random choices in a stack-addressed database.

    Every random XRP application is recorded under a tuple "stack" address
    encoding the evaluation path from a directive id down to the application
    (e.g. ``(id, 'arg0', 'apply', '<arg hashes>')``).  Re-evaluating a
    directive therefore finds previously sampled values by address.  All
    probabilities are log probabilities.
    """

    def __init__(self):
        self.engine_type = 'randomdb'
        self.db = RandomChoiceDict()  # stack tuple -> (xrp, value, args, False)
        self.db_noise = {}  # stack tuple -> (xrp, value, args, True): forced observe values
        self.log = []  # insert/remove operations recorded since the last save()

        # ALWAYS WORKING WITH LOG PROBABILITIES
        self.uneval_p = 0
        self.eval_p = 0
        self.p = 0

        self.env = Environment()

        self.assumes = {}  # id -> (varname, expr)
        self.observes = {}  # id -> (expr, obs_val)
        self.predicts = {}  # id -> expr
        self.vars = {}  # varname -> expr

    def reset(self):
        """Reinitialise the engine to an empty state."""
        self.__init__()

    def assume(self, varname, expr, id):
        """Evaluate ``expr`` (resampling its choices), bind it to ``varname`` and return it."""
        self.assumes[id] = (varname, expr)
        self.vars[varname] = expr
        value = self.evaluate(expr, self.env, reflip=True, stack=[id])
        self.env.set(varname, value)
        return value

    def observe(self, expr, obs_val, id):
        """Evaluate the application ``expr`` with its outermost XRP forced to ``obs_val``."""
        # NOTE(review): self.observes is keyed by directive id, so this compares
        # an expression hash against ids and does not detect a repeated
        # expression.  Kept unchanged pending a decision on the intended check.
        if expr.hashval in self.observes:
            raise RException('Already observed %s' % str(expr))
        self.observes[id] = (expr, obs_val)
        # bit of a hack, here, to make it recognize same things as with noisy_expr
        self.evaluate(expr, self.env, reflip=False, stack=[id], xrp_force_val=obs_val)
        return expr.hashval

    def predict(self, expr, id):
        """Evaluate ``expr`` (resampling its choices) and return its value."""
        self.predicts[id] = expr
        return self.evaluate(expr, self.env, True, [id])

    def forget(self, id):
        """Forget observe ``id`` and remove every random choice recorded under it.

        Raises:
            RException: if ``id`` is not an observe.
        """
        if id not in self.observes:
            raise RException('Can only forget observes')
        # observe() evaluates under stack [id]: the forced value lives at (id,)
        # and sub-evaluations at (id, ...).  The old code removed ['obs', id],
        # an address nothing ever inserts, so forget always failed.
        self.unevaluate([id])
        del self.observes[id]
        return

    def insert(self, stack, xrp, value, args, is_obs_noise=False, memorize=True):
        """Record ``value`` for ``xrp`` at ``stack``, replacing any existing entry."""
        stack = tuple(stack)
        if self.has(stack):
            # Honour memorize here too, so restore() never appends to the log
            # it is iterating over.
            self.remove(stack, memorize)
        prob = xrp.prob(value, args)
        self.p += prob
        xrp.incorporate(value, args)
        if is_obs_noise:
            self.db_noise[stack] = (xrp, value, args, True)
        else:
            self.db[stack] = (xrp, value, args, False)
            self.eval_p += prob  # hmmm..
        if memorize:
            self.log.append(('insert', stack, xrp, value, args, is_obs_noise))

    def remove(self, stack, memorize=True):
        """Remove the entry at ``stack`` and un-incorporate its value from the XRP."""
        stack = tuple(stack)
        assert self.has(stack)
        (xrp, value, args, is_obs_noise) = self.get(stack)
        xrp.remove(value, args)
        prob = xrp.prob(value, args)
        self.p -= prob
        if is_obs_noise:
            del self.db_noise[stack]
        else:
            del self.db[stack]
        self.uneval_p += prob  # previously unindented...
        if memorize:
            self.log.append(('remove', stack, xrp, value, args, is_obs_noise))

    def has(self, stack):
        """True if either database holds an entry at ``stack``."""
        stack = tuple(stack)
        return stack in self.db or stack in self.db_noise

    def get(self, stack):
        """Return the (xrp, value, args, is_obs_noise) entry at ``stack``.

        Raises:
            RException: if no entry exists.
        """
        stack = tuple(stack)
        if stack in self.db:
            return self.db[stack]
        if stack in self.db_noise:
            return self.db_noise[stack]
        raise RException('Failed to get stack %s' % str(stack))

    def infer(self):
        """Reflip one randomly chosen regular entry; do nothing if none is available."""
        try:
            stack = self.random_stack()
        except Exception:
            return
        self.reflip(stack)

    def random_stack(self):
        """Return a random key of the regular (non-observation) database."""
        return self.db.randomKey()

    def evaluate_recurse(self, subexpr, env, reflip, stack, addition):
        """Evaluate ``subexpr`` at address ``stack + addition``."""
        # list.extend() returns None; the old code passed that None down as
        # the stack, so every recursive evaluation failed.
        newstack = list(stack) + list(addition)
        return self.evaluate(subexpr, env, reflip, newstack)

    def binary_op_evaluate(self, expr, env, reflip, stack):
        """Evaluate the two operands of a binary expression."""
        val1 = self.evaluate_recurse(expr.children[0], env, reflip, stack, ['operand0'])
        val2 = self.evaluate_recurse(expr.children[1], env, reflip, stack, ['operand1'])
        return (val1, val2)

    def children_evaluate(self, expr, env, reflip, stack):
        """Evaluate every child of ``expr`` at addresses 'child0', 'child1', ..."""
        return [self.evaluate_recurse(child, env, reflip, stack, ['child' + str(i)])
                for i, child in enumerate(expr.children)]

    # Draws a sample value (without re-sampling other values) given its parents, and sets it
    def evaluate(self, expr, env=None, reflip=False, stack=None, xrp_force_val=None):
        """Evaluate ``expr`` at address ``stack``.

        Args:
            expr: expression to evaluate.
            env: environment (defaults to the global one).
            reflip: if true, existing random choices are resampled.
            stack: address of this evaluation (defaults to the empty address).
            xrp_force_val: if given, ``expr`` must be an application whose
                XRP value is forced to this (used by observe).

        Raises:
            RException: on arity mismatch, a non-applicable operator, or an
                unknown expression type.
        """
        if env is None:
            env = self.env
        # None instead of a shared mutable default; copy so callers' lists
        # (or tuples) are never aliased.
        stack = [] if stack is None else list(stack)
        if xrp_force_val is not None:
            assert expr.type == 'apply'

        if expr.type == 'value':
            val = expr.val
        elif expr.type == 'variable':
            val = env.lookup(expr.name)[0]
        elif expr.type == 'if':
            cond = self.evaluate_recurse(expr.cond, env, reflip, stack, ['cond'])
            # Drop the choices recorded on the branch not taken.
            if cond.bool:
                self.unevaluate(stack + ['false'])
                val = self.evaluate_recurse(expr.true, env, reflip, stack, ['true'])
            else:
                self.unevaluate(stack + ['true'])
                val = self.evaluate_recurse(expr.false, env, reflip, stack, ['false'])
        #elif expr.type == 'switch':
        #    index = self.evaluate_recurse(expr.index, env, reflip, stack, ['index'])
        #    assert 0 <= index.num < expr.n
        #    for i in range(expr.n):
        #        if i != index.num:
        #            self.unevaluate(stack + ['child' + str(i)])
        #    val = self.evaluate_recurse(expr.children[index.num], env, reflip, stack, ['child' + str(index.num)])
        elif expr.type == 'let':
            # NOTE: this really is a let*
            n = len(expr.vars)
            assert len(expr.expressions) == n
            new_env = env
            for i in range(n):  # Bind variables
                new_env = new_env.spawn_child()
                val = self.evaluate_recurse(expr.expressions[i], new_env, reflip, stack, ['let' + str(i)])
                new_env.set(expr.vars[i], val)
                if val.type == 'procedure':
                    val.env = new_env
            new_body = expr.body.replace(new_env)
            val = self.evaluate_recurse(new_body, new_env, reflip, stack, ['body'])
        elif expr.type == 'apply':
            val = self._evaluate_apply(expr, env, reflip, stack, xrp_force_val)
        elif expr.type == 'function':
            new_env = env.spawn_child()
            bound = dict((var, True) for var in expr.vars)  # Bind variables
            procedure_body = expr.body.replace(new_env, bound)
            val = Procedure(expr.vars, procedure_body, env)
        elif expr.type == '=':
            (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
            val = val1.__eq__(val2)
        elif expr.type == '<':
            (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
            val = val1.__lt__(val2)
        elif expr.type == '>':
            (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
            val = val1.__gt__(val2)
        elif expr.type == '<=':
            (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
            val = val1.__le__(val2)
        elif expr.type == '>=':
            (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack)
            val = val1.__ge__(val2)
        elif expr.type == '&':
            val = BoolValue(True)
            for x in self.children_evaluate(expr, env, reflip, stack):
                val = val.__and__(x.bool)
        elif expr.type == '^':
            # False is the identity of xor; starting from True (as before)
            # produced the negation of the xor of the children.
            val = BoolValue(False)
            for x in self.children_evaluate(expr, env, reflip, stack):
                val = val.__xor__(x.bool)
        elif expr.type == '|':
            val = BoolValue(False)
            for x in self.children_evaluate(expr, env, reflip, stack):
                val = val.__or__(x.bool)
        elif expr.type == '~':
            negval = self.evaluate_recurse(expr.children[0], env, reflip, stack, ['neg'])
            val = negval.__inv__()
        elif expr.type == '+':
            val = NatValue(0)
            for x in self.children_evaluate(expr, env, reflip, stack):
                val = val.__add__(x)
        elif expr.type == '-':
            val1 = self.evaluate_recurse(expr.children[0], env, reflip, stack, ['sub0'])
            val2 = self.evaluate_recurse(expr.children[1], env, reflip, stack, ['sub1'])
            val = val1.__sub__(val2)
        elif expr.type == '*':
            val = NatValue(1)
            for x in self.children_evaluate(expr, env, reflip, stack):
                val = val.__mul__(x)
        elif expr.type == '/':
            val1 = self.evaluate_recurse(expr.children[0], env, reflip, stack, ['div0'])
            val2 = self.evaluate_recurse(expr.children[1], env, reflip, stack, ['div1'])
            val = val1.__div__(val2)
        else:
            raise RException('Invalid expression type %s' % expr.type)
        return val

    def _evaluate_apply(self, expr, env, reflip, stack, xrp_force_val):
        """Evaluate an 'apply' expression: a procedure call or an XRP application."""
        n = len(expr.children)
        args = [self.evaluate_recurse(expr.children[i], env, reflip, stack, ['arg' + str(i)])
                for i in range(n)]
        op = self.evaluate_recurse(expr.op, env, reflip, stack, ['operator'])
        # Applications are addressed by their argument hashes, so different
        # arguments get different addresses under stack + ['apply'].
        addition = ','.join([x.str_hash for x in args])

        if op.type == 'procedure':
            self.unevaluate(stack + ['apply'], addition)
            if n != len(op.vars):
                raise RException('Procedure should have %d arguments. \nVars were \n%s\n, but children were \n%s.'
                                 % (n, op.vars, expr.children))
            new_env = op.env.spawn_child()
            for var, arg in zip(op.vars, args):
                new_env.set(var, arg)
            return self.evaluate_recurse(op.body, new_env, reflip, stack, ['apply', addition])

        if op.type != 'xrp':
            raise RException('Must apply either a procedure or xrp. Instead got expression %s' % str(op))

        self.unevaluate(stack + ['apply'], addition)

        if xrp_force_val is not None:
            # Observed application: record the forced value at `stack` itself.
            assert not reflip
            if self.has(stack):
                self.remove(stack)
            self.insert(stack, op.xrp, xrp_force_val, args, True)
            return xrp_force_val

        substack = stack + ['apply', addition]
        if self.has(substack) and not reflip:
            (xrp, val, dbargs, is_obs_noise) = self.get(substack)
            assert not is_obs_noise
            return val

        # Remove the old value before sampling so the XRP no longer counts it.
        if self.has(substack):
            self.remove(substack)
        if op.xrp.is_mem_proc():
            val = op.xrp.apply_mem(args, self, stack)
        else:
            val = op.xrp.apply(args)
        self.insert(substack, op.xrp, val, args)
        return val

    def unevaluate(self, uneval_stack, args=None):
        """Remove every entry whose address starts with ``uneval_stack``.

        If ``args`` is given, entries whose next address element equals
        ``args`` are kept.  Callers pass the argument-hash string of the
        application being evaluated, so only stale applications are removed.
        """
        # The old code did tuple(args) unconditionally, turning the hash
        # string into a tuple of characters that never matched the stored
        # string, so the current application was always removed as well.
        if isinstance(args, list):
            args = tuple(args)
        uneval_stack = list(uneval_stack)
        depth = len(uneval_stack)

        candidates = [s for s in self.db] + [s for s in self.db_noise]
        to_delete = []
        for tuple_stack in candidates:
            stack = list(tuple_stack)
            if stack[:depth] != uneval_stack:
                continue
            if args is None:
                to_delete.append(tuple_stack)
            else:
                assert len(stack) > depth
                if stack[depth] != args:
                    to_delete.append(tuple_stack)

        for tuple_stack in to_delete:
            self.remove(tuple_stack)

    def save(self):
        """Start a fresh undo log and reset the per-move probability accumulators."""
        self.log = []
        self.uneval_p = 0
        self.eval_p = 0

    def restore(self):
        """Undo every logged insert/remove since the last ``save()``, newest first."""
        self.log.reverse()
        for (kind, stack, xrp, value, args, is_obs_noise) in self.log:
            if kind == 'insert':
                self.remove(stack, False)
            else:
                assert kind == 'remove'
                self.insert(stack, xrp, value, args, is_obs_noise, False)

    def reflip(self, stack):
        """Resample the choice at ``stack``, re-run, then accept or restore (MH)."""
        (xrp, val, args, is_obs_noise) = self.get(stack)

        #debug = True
        debug = False

        old_p = self.p
        old_to_new_q = -math.log(len(self.db))

        if debug:
            print("old_db %s" % (self,))

        self.save()

        self.remove(stack)
        if xrp.is_mem_proc():
            new_val = xrp.apply(args, list(stack))
        else:
            new_val = xrp.apply(args)
        self.insert(stack, xrp, new_val, args)

        if debug:
            print("\nCHANGING %s\n TO : %s\n" % (stack, new_val))

        # Value.__eq__ returns a BoolValue (see the '=' case of evaluate and
        # ReducedTraces.reflip), so test .bool rather than object truthiness.
        if val.__eq__(new_val).bool:
            return

        self.rerun(False)
        new_p = self.p
        new_to_old_q = self.uneval_p - math.log(len(self.db))
        old_to_new_q += self.eval_p

        if debug:
            print("new db %s\nq(old -> new) : %s q(new -> old) : %s p(old) : %s p(new) : %s "
                  "log transition prob : %s\n"
                  % (self, old_to_new_q, new_to_old_q, old_p, new_p,
                     new_p + new_to_old_q - old_p - old_to_new_q))

        # NOTE(review): when this product is <= 0 the move is accepted
        # unconditionally; the intent of this guard is unclear.
        if old_p * old_to_new_q > 0:
            u = rrandom.random.random()
            # math.log(0) raises; log(0) is -inf, which never rejects.
            log_u = math.log(u) if u > 0 else float('-inf')
            if new_p + new_to_old_q - old_p - old_to_new_q < log_u:
                self.restore()
                if debug:
                    print('restore')
                self.rerun(False)

        if debug:
            print("new db %s" % (self,))
            print("\n-----------------------------------------\n")

    def __str__(self):
        lines = ['DB with state:', ' Regular Flips:']
        lines.extend(' %s <- %s' % (self.db[s][1].val, s) for s in self.db)
        lines.append(' Observe Flips:')
        lines.extend(' %s <- %s' % (self.db_noise[s][1].val, s) for s in self.db_noise)
        return '\n'.join(lines)

    def __contains__(self, stack):
        # Previously called self.has(self, stack): a TypeError on every use.
        return self.has(stack)

    def __getitem__(self, stack):
        # Previously called self.get(self, stack): a TypeError on every use.
        return self.get(stack)
class Traces(Engine): def __init__(self): self.engine_type = "traces" self.assumes = {} # id -> evalnode self.observes = {} # id -> evalnode self.predicts = {} # id -> evalnode self.directives = [] self.db = RandomChoiceDict() self.weighted_db = WeightedRandomChoiceDict() self.choices = {} # hash -> evalnode self.xrps = {} # hash -> (xrp, set of application nodes) self.env = EnvironmentNode() self.p = 0 self.old_to_new_q = 0 self.new_to_old_q = 0 self.eval_xrps = [] # (xrp, args, val) self.uneval_xrps = [] # (xrp, args, val) self.debug = False # Stuff for restoring self.application_reflip = False self.reflip_node = EvalNode(self, self.env, VarExpression("")) self.nodes = [] self.old_vals = [] self.new_vals = [] self.old_val = Value() self.new_val = Value() self.reflip_xrp = XRP() self.made_proposals = 0 self.accepted_proposals = 0 self.mhstats_details = False self.mhstats = {} return def mhstats_on(self): self.mhstats_details = True def mhstats_off(self): self.mhstats_details = False def mhstats_aggregated(self): d = {} d["made-proposals"] = self.made_proposals d["accepted-proposals"] = self.accepted_proposals return d def mhstats_detailed(self): return self.mhstats def get_log_score(self, id): if id == -1: return self.p else: return self.observes[id].p def weight(self): return self.db.weight() + self.weighted_db.weight() def random_choices(self): return self.db.__len__() + self.weighted_db.__len__() def assume(self, name, expr, id): evalnode = EvalNode(self, self.env, expr) if id != -1: self.assumes[id] = evalnode if id != len(self.directives): raise RException("Id %d does not agree with directives length of %d" % (id, len(self.directives))) self.directives.append("assume") val = evalnode.evaluate() self.env.add_assume(name, evalnode) evalnode.add_assume(name, self.env) self.env.set(name, val) return val def predict(self, expr, id): evalnode = EvalNode(self, self.env, expr) if id != len(self.directives): raise RException("Id %d does not agree with directives 
length of %d" % (id, len(self.directives))) self.directives.append("predict") self.predicts[id] = evalnode val = evalnode.evaluate(False) return val def observe(self, expr, obs_val, id): evalnode = EvalNode(self, self.env, expr) if id != len(self.directives): raise RException("Id %d does not agree with directives length of %d" % (id, len(self.directives))) self.directives.append("observe") self.observes[id] = evalnode evalnode.observed = True evalnode.observe_val = obs_val val = evalnode.evaluate() return val def forget(self, id): if id in self.observes: d = self.observes elif id in self.predicts: d = self.predicts else: raise RException("Can only forget predicts and observes") evalnode = d[id] evalnode.unevaluate() # del d[id] return def report_value(self, id): node = self.get_directive_node(id) if not node.active: raise RException("Error. Perhaps this directive was forgotten?") val = node.val return val def get_directive_node(self, id): if self.directives[id] == "assume": node = self.assumes[id] elif self.directives[id] == "observe": node = self.observes[id] elif self.directives[id] == "predict": node = self.predicts[id] else: raise RException("Invalid directive") return node def add_accepted_proposal(self, hashval): if self.mhstats_details: if hashval in self.mhstats: self.mhstats[hashval]["accepted-proposals"] += 1 else: self.mhstats[hashval] = {} self.mhstats[hashval]["accepted-proposals"] = 1 self.mhstats[hashval]["made-proposals"] = 0 self.accepted_proposals += 1 def add_made_proposal(self, hashval): if self.mhstats_details: if hashval in self.mhstats: self.mhstats[hashval]["made-proposals"] += 1 else: self.mhstats[hashval] = {} self.mhstats[hashval]["made-proposals"] = 1 self.mhstats[hashval]["accepted-proposals"] = 0 self.made_proposals += 1 def reflip(self, hashval): if self.debug: print self if hashval in self.choices: self.application_reflip = True self.reflip_node = self.choices[hashval] if not self.reflip_node.random_xrp_apply: raise 
RException("Reflipping something which isn't a random xrp application") if self.reflip_node.val is None: raise RException("Reflipping something which previously had value None") else: self.application_reflip = False # internal reflip (self.reflip_xrp, nodes) = self.xrps[hashval] self.nodes = list_nodes(nodes) self.old_to_new_q = 0 self.new_to_old_q = 0 old_p = self.p self.old_to_new_q = -math.log(self.weight()) if self.application_reflip: self.old_val = self.reflip_node.val self.new_val = self.reflip_node.reflip() else: args_list = [] self.old_vals = [] for node in self.nodes: if not node.xrp_apply: raise RException("non-XRP application node being used in transition") args_list.append(node.args) self.old_vals.append(node.val) self.old_to_new_q += math.log(self.reflip_xrp.state_weight()) old_p += self.reflip_xrp.theta_prob() self.new_vals, q_forwards, q_back = self.reflip_xrp.theta_mh_prop(args_list, self.old_vals) self.old_to_new_q += q_forwards self.new_to_old_q += q_back for i in range(len(self.nodes)): node = self.nodes[i] val = self.new_vals[i] node.set_val(val) node.propagate_up(False, True) new_p = self.p self.new_to_old_q -= math.log(self.weight()) if not self.application_reflip: new_p += self.reflip_xrp.theta_prob() self.new_to_old_q += math.log(self.reflip_xrp.state_weight()) if self.debug: if self.application_reflip: print "\nCHANGING VAL OF ", self.reflip_node, "\n FROM : ", self.old_val, "\n TO : ", self.new_val, "\n" if (self.old_val.__eq__(self.new_val)).bool: print "SAME VAL" else: print "TRANSITIONING STATE OF ", self.reflip_xrp print "new db", self print "\nq(old -> new) : ", math.exp(self.old_to_new_q) print "q(new -> old) : ", math.exp(self.new_to_old_q) print "p(old) : ", math.exp(old_p) print "p(new) : ", math.exp(new_p) print "transition prob : ", math.exp(new_p + self.new_to_old_q - old_p - self.old_to_new_q), "\n" return old_p, self.old_to_new_q, new_p, self.new_to_old_q def restore(self): if self.application_reflip: 
self.reflip_node.restore(self.old_val) else: for i in range(len(self.nodes)): node = self.nodes[i] node.set_val(self.old_vals[i]) node.propagate_up(True, True) self.reflip_xrp.theta_mh_restore() if self.debug: print "restore" print "\n-----------------------------------------\n" def keep(self): if self.application_reflip: self.reflip_node.restore(self.new_val) # NOTE: Is necessary for correctness, as we must forget old branch else: for i in range(len(self.nodes)): node = self.nodes[i] node.restore(self.new_vals[i]) self.reflip_xrp.theta_mh_keep() if self.debug: print "keep" print "\n-----------------------------------------\n" def add_for_transition(self, xrp, evalnode): hashval = xrp.__hash__() if hashval not in self.xrps: self.xrps[hashval] = (xrp, {}) (xrp, evalnodes) = self.xrps[hashval] evalnodes[evalnode] = True weight = xrp.state_weight() self.delete_from_db(hashval) self.add_to_db(hashval, weight) # Add an XRP application node to the db def add_xrp(self, xrp, args, val, evalnode, forcing=False): self.eval_xrps.append((xrp, args, val)) evalnode.forcing = forcing if not xrp.resample: evalnode.random_xrp_apply = True prob = xrp.prob(val, args) evalnode.setargs(xrp, args, prob) xrp.incorporate(val, args) xrp.make_link(evalnode, args) if not forcing: self.old_to_new_q += prob self.p += prob self.add_for_transition(xrp, evalnode) if xrp.resample: return weight = xrp.weight(args) try: self.new_to_old_q += math.log(weight) except: pass # This is only necessary if we're reflipping if ( self.weighted_db.__contains__(evalnode.hashval) or self.db.__contains__(evalnode.hashval) or (evalnode.hashval in self.choices) ): raise RException("DB already had this evalnode") self.choices[evalnode.hashval] = evalnode self.add_to_db(evalnode.hashval, weight) def add_to_db(self, hashval, weight): if weight == 0: return elif weight == 1: self.db.__setitem__(hashval, True) else: self.weighted_db.__setitem__(hashval, True, weight) def remove_for_transition(self, xrp, evalnode): 
hashval = xrp.__hash__() (xrp, evalnodes) = self.xrps[hashval] del evalnodes[evalnode] if len(evalnodes) == 0: del self.xrps[hashval] self.delete_from_db(hashval) def remove_xrp(self, xrp, args, val, evalnode): self.uneval_xrps.append((xrp, args, val)) xrp.break_link(evalnode, args) xrp.remove(val, args) prob = xrp.prob(val, args) if not evalnode.forcing: self.new_to_old_q += prob self.p -= prob self.remove_for_transition(xrp, evalnode) if xrp.resample: return xrp = evalnode.xrp try: # TODO: dont do this here.. dont do cancellign self.old_to_new_q += math.log(xrp.weight(evalnode.args)) except: pass # This fails when restoring/keeping, for example if evalnode.hashval not in self.choices: raise RException("Choices did not already have this evalnode") else: del self.choices[evalnode.hashval] self.delete_from_db(evalnode.hashval) def delete_from_db(self, hashval): if self.db.__contains__(hashval): self.db.__delitem__(hashval) elif self.weighted_db.__contains__(hashval): self.weighted_db.__delitem__(hashval) def randomKey(self): if rrandom.random.random() * self.weight() > self.db.weight(): return self.weighted_db.randomKey() else: return self.db.randomKey() def infer(self): try: hashval = self.randomKey() except: raise RException("Program has no randomness!") old_p, old_to_new_q, new_p, new_to_old_q = self.reflip(hashval) p = rrandom.random.random() if new_p + self.new_to_old_q - old_p - self.old_to_new_q < math.log(p): self.restore() else: self.keep() self.add_accepted_proposal(hashval) self.add_made_proposal(hashval) def reset(self): self.__init__() def __str__(self): string = "EvalNodeTree:" for evalnode in self.assumes.values(): string += evalnode.str_helper() for evalnode in self.observes.values(): string += evalnode.str_helper() for evalnode in self.predicts.values(): string += evalnode.str_helper() return string
class RandomDB(Engine): def __init__(self): #self.db = {} self.engine_type = 'randomdb' self.db = RandomChoiceDict() self.db_noise = {} self.log = [] # ALWAYS WORKING WITH LOG PROBABILITIES self.uneval_p = 0 self.eval_p = 0 self.p = 0 self.env = Environment() self.assumes = {} self.observes = {} self.predicts = {} self.vars = {} def reset(self): self.__init__() def assume(self, varname, expr, id): self.assumes[id] = (varname, expr) self.vars[varname] = expr value = self.evaluate(expr, self.env, reflip = True, stack = [id]) self.env.set(varname, value) return value def observe(self, expr, obs_val, id): if expr.hashval in self.observes: raise RException('Already observed %s' % str(expr)) self.observes[id] = (expr, obs_val) # bit of a hack, here, to make it recognize same things as with noisy_expr self.evaluate(expr, self.env, reflip = False, stack = [id], xrp_force_val = obs_val) return expr.hashval def predict(self, expr, id): self.predicts[id] = expr return self.evaluate(expr, self.env, True, [id]) def forget(self, id): self.remove(['obs', id]) assert id in self.observes del self.observes[id] return def insert(self, stack, xrp, value, args, is_obs_noise = False, memorize = True): stack = tuple(stack) if self.has(stack): self.remove(stack) prob = xrp.prob(value, args) self.p += prob xrp.incorporate(value, args) if is_obs_noise: self.db_noise[stack] = (xrp, value, args, True) else: self.db[stack] = (xrp, value, args, False) if not is_obs_noise: self.eval_p += prob # hmmm.. if memorize: self.log.append(('insert', stack, xrp, value, args, is_obs_noise)) def remove(self, stack, memorize = True): stack = tuple(stack) assert self.has(stack) (xrp, value, args, is_obs_noise) = self.get(stack) xrp.remove(value, args) prob = xrp.prob(value, args) self.p -= prob if is_obs_noise: del self.db_noise[stack] else: del self.db[stack] self.uneval_p += prob # previously unindented... 
if memorize: self.log.append(('remove', stack, xrp, value, args, is_obs_noise)) def has(self, stack): stack = tuple(stack) return ((stack in self.db) or (stack in self.db_noise)) def get(self, stack): stack = tuple(stack) if stack in self.db: return self.db[stack] elif stack in self.db_noise: return self.db_noise[stack] else: raise RException('Failed to get stack %s' % str(stack)) def infer(self): try: stack = self.random_stack() except: return self.reflip(stack) def random_stack(self): key = self.db.randomKey() return key def evaluate_recurse(self, subexpr, env, reflip, stack, addition): newstack = [s for s in stack].extend(addition) val = self.evaluate(subexpr, env, reflip, newstack) return val def binary_op_evaluate(self, expr, env, reflip, stack): val1 = self.evaluate_recurse(expr.children[0], env, reflip, stack, ['operand0']) val2 = self.evaluate_recurse(expr.children[1], env, reflip, stack, ['operand1']) return (val1 , val2) def children_evaluate(self, expr, env, reflip, stack): return [self.evaluate_recurse(expr.children[i], env, reflip, stack, ['child' + str(i)]) for i in range(len(expr.children))] # Draws a sample value (without re-sampling other values) given its parents, and sets it def evaluate(self, expr, env = None, reflip = False, stack = [], xrp_force_val = None): if env is None: env = self.env if xrp_force_val is not None: assert expr.type == 'apply' if expr.type == 'value': val = expr.val elif expr.type == 'variable': var = expr.name (val, lookup_env) = env.lookup(var) elif expr.type == 'if': cond = self.evaluate_recurse(expr.cond, env, reflip, stack , ['cond']) if cond.bool: self.unevaluate(stack + ['false']) val = self.evaluate_recurse(expr.true, env, reflip, stack , ['true']) else: self.unevaluate(stack + ['true']) val = self.evaluate_recurse(expr.false, env, reflip, stack , ['false']) #elif expr.type == 'switch': # index = self.evaluate_recurse(expr.index, env, reflip, stack , ['index']) # assert 0 <= index.num < expr.n # for i in 
range(expr.n): # if i != index.num: # self.unevaluate(stack + ['child' + str(i)]) # val = self.evaluate_recurse(expr.children[index.num], env, reflip, stack, ['child' + str(index.num)]) elif expr.type == 'let': # NOTE: this really is a let* n = len(expr.vars) assert len(expr.expressions) == n values = [] new_env = env for i in range(n): # Bind variables new_env = new_env.spawn_child() val = self.evaluate_recurse(expr.expressions[i], new_env, reflip, stack, ['let' + str(i)]) values.append(val) new_env.set(expr.vars[i], values[i]) if val.type == 'procedure': val.env = new_env new_body = expr.body.replace(new_env) val = self.evaluate_recurse(new_body, new_env, reflip, stack, ['body']) elif expr.type == 'apply': n = len(expr.children) args = [self.evaluate_recurse(expr.children[i], env, reflip, stack, ['arg' + str(i)]) for i in range(n)] op = self.evaluate_recurse(expr.op, env, reflip, stack , ['operator']) addition = ','.join([x.str_hash for x in args]) if op.type == 'procedure': self.unevaluate(stack + ['apply'], addition) if n != len(op.vars): raise RException('Procedure should have %d arguments. \nVars were \n%s\n, but children were \n%s.' 
% (n, op.vars, expr.children)) new_env = op.env.spawn_child() for i in range(n): new_env.set(op.vars[i], args[i]) val = self.evaluate_recurse(op.body, new_env, reflip, stack, ['apply', addition]) elif op.type == 'xrp': self.unevaluate(stack + ['apply'], addition) if xrp_force_val is not None: assert not reflip if self.has(stack): self.remove(stack) self.insert(stack, op.xrp, xrp_force_val, args, True) val = xrp_force_val else: substack = stack + ['apply', addition] if not self.has(substack): if op.xrp.is_mem_proc(): val = op.xrp.apply_mem(args, self, stack) else: val = op.xrp.apply(args) self.insert(substack, op.xrp, val, args) else: if reflip: self.remove(substack) if op.xrp.is_mem_proc(): val = op.xrp.apply_mem(args, self, stack) else: val = op.xrp.apply(args) self.insert(substack, op.xrp, val, args) else: (xrp, val, dbargs, is_obs_noise) = self.get(substack) assert not is_obs_noise else: raise RException('Must apply either a procedure or xrp. Instead got expression %s' % str(op)) elif expr.type == 'function': n = len(expr.vars) new_env = env.spawn_child() bound = {} for i in range(n): # Bind variables bound[expr.vars[i]] = True procedure_body = expr.body.replace(new_env, bound) val = Procedure(expr.vars, procedure_body, env) elif expr.type == '=': (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack) val = val1.__eq__(val2) elif expr.type == '<': (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack) val = val1.__lt__(val2) elif expr.type == '>': (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack) val = val1.__gt__(val2) elif expr.type == '<=': (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack) val = val1.__le__(val2) elif expr.type == '>=': (val1, val2) = self.binary_op_evaluate(expr, env, reflip, stack) val = val1.__ge__(val2) elif expr.type == '&': vals = self.children_evaluate(expr, env, reflip, stack) andval = BoolValue(True) for x in vals: andval = andval.__and__(x.bool) val = andval elif expr.type == '^': 
vals = self.children_evaluate(expr, env, reflip, stack) xorval = BoolValue(True) for x in vals: xorval = xorval.__xor__(x.bool) val = xorval elif expr.type == '|': vals = self.children_evaluate(expr, env, reflip, stack) orval = BoolValue(False) for x in vals: orval = orval.__or__(x.bool) val = orval elif expr.type == '~': negval = self.evaluate_recurse(expr.children[0], env, reflip, stack, ['neg']) val = negval.__inv__() elif expr.type == '+': vals = self.children_evaluate(expr, env, reflip, stack) sum_val = NatValue(0) for x in vals: sum_val = sum_val.__add__(x) val = sum_val elif expr.type == '-': val1 = self.evaluate_recurse(expr.children[0], env, reflip, stack , ['sub0']) val2 = self.evaluate_recurse(expr.children[1], env, reflip, stack , ['sub1']) val = val1.__sub__(val2) elif expr.type == '*': vals = self.children_evaluate(expr, env, reflip, stack) prod_val = NatValue(1) for x in vals: prod_val = prod_val.__mul__(x) val = prod_val elif expr.type == '/': val1 = self.evaluate_recurse(expr.children[0], env, reflip, stack , ['div0']) val2 = self.evaluate_recurse(expr.children[1], env, reflip, stack , ['div1']) val = val1.__div__(val2) else: raise RException('Invalid expression type %s' % expr.type) return val def unevaluate(self, uneval_stack, args = None): if args is not None: args = tuple(args) to_unevaluate = [] for tuple_stack in self.db: to_unevaluate.append(tuple_stack) for tuple_stack in self.db_noise: to_unevaluate.append(tuple_stack) to_delete = [] for tuple_stack in to_unevaluate: stack = list(tuple_stack) if len(stack) >= len(uneval_stack) and stack[:len(uneval_stack)] == uneval_stack: if args is None: to_delete.append(tuple_stack) else: assert len(stack) > len(uneval_stack) if stack[len(uneval_stack)] != args: to_delete.append(tuple_stack) for tuple_stack in to_delete: self.remove(tuple_stack) def save(self): self.log = [] self.uneval_p = 0 self.eval_p = 0 def restore(self): self.log.reverse() for (type, stack, xrp, value, args, is_obs_noise) in 
self.log: if type == 'insert': self.remove(stack, False) else: assert type == 'remove' self.insert(stack, xrp, value, args, is_obs_noise, False) def reflip(self, stack): (xrp, val, args, is_obs_noise) = self.get(stack) #debug = True debug = False old_p = self.p old_to_new_q = - math.log(len(self.db)) if debug: print "old_db", self self.save() self.remove(stack) if xrp.is_mem_proc(): new_val = xrp.apply(args, list(stack)) else: new_val = xrp.apply(args) self.insert(stack, xrp, new_val, args) if debug: print "\nCHANGING ", stack, "\n TO : ", new_val, "\n" if val == new_val: return self.rerun(False) new_p = self.p new_to_old_q = self.uneval_p - math.log(len(self.db)) old_to_new_q += self.eval_p if debug: print "new db", self, \ "\nq(old -> new) : ", old_to_new_q, \ "q(new -> old) : ", new_to_old_q, \ "p(old) : ", old_p, \ "p(new) : ", new_p, \ 'log transition prob : ', new_p + new_to_old_q - old_p - old_to_new_q , "\n" if old_p * old_to_new_q > 0: p = rrandom.random.random() if new_p + new_to_old_q - old_p - old_to_new_q < math.log(p): self.restore() if debug: print 'restore' self.rerun(False) if debug: print "new db", self print "\n-----------------------------------------\n" def __str__(self): string = 'DB with state:' string += '\n Regular Flips:' for s in self.db: string += '\n %s <- %s' % (self.db[s][1].val, s) string += '\n Observe Flips:' for s in self.db_noise: string += '\n %s <- %s' % (self.db_noise[s][1].val, s) return string def __contains__(self, stack): return self.has(self, stack) def __getitem__(self, stack): return self.get(self, stack)