def __init__(self, latent, data, kernel_config, seed=None, fixed_k=False,
             relation_class=None, threadpool=None):
    """Build the IRM model from ``data`` and prepare it for inference.

    :param latent: initial latent state, applied with
        ``irmio.set_model_latent``.
    :param data: input data handed to ``irmio.create_model_from_data``.
    :param kernel_config: sequence of ``(kernel_name, params)`` pairs; stored
        on ``self.kernel_config``.
    :param seed: optional seed for the ``pyirm.RNG``; left unseeded if None.
    :param fixed_k: forwarded to ``irmio.create_model_from_data``.
    :param relation_class: relation implementation; defaults to
        ``pyirmutil.Relation``.
    :param threadpool: stored on ``self.threadpool``.

    When ``kernel_config`` consists of exactly one ``"parallel_tempering"``
    kernel, ``self.PT`` is True and ``self.chain_states`` holds one latent
    snapshot per entry of that kernel's ``'temps'``. Otherwise ``self.PT``
    is False and ``chain_states`` is not set.
    """
    # NOTE(review): an old "FIXME add seed" lived here; the seed is applied
    # below -- confirm nothing else (e.g. recording it) was intended.
    # Single-argument print yields the same output on Python 2 and 3.
    print("FIXED_K= %s" % (fixed_k,))

    # create the model
    self.rng = pyirm.RNG()
    if seed is not None:
        pyirm.set_seed(self.rng, seed)

    if relation_class is None:
        relation_class = pyirmutil.Relation
    print("Running with relation_class= %s" % (relation_class,))

    self.model = irmio.create_model_from_data(data, rng=self.rng,
                                              fixed_k=fixed_k,
                                              relation_class=relation_class)
    irmio.set_model_latent(self.model, latent, self.rng)

    self.iters = 0
    self.kernel_config = kernel_config

    # Parallel tempering is only recognised when it is the sole kernel.
    self.PT = (len(kernel_config) == 1
               and kernel_config[0][0] == "parallel_tempering")
    if self.PT:
        # One latent snapshot per temperature, all taken from the freshly
        # initialised model (presumably independent copies -- depends on
        # irmio.get_latent).
        temps = kernel_config[0][1]['temps']
        self.chain_states = [irmio.get_latent(self.model) for _ in temps]

    self.threadpool = threadpool
def do_inference(irm_model, rng, kernel_config, iteration, reverse=False,
                 states_at_temps=None, threadpool=None):
    """Apply every kernel in ``kernel_config`` to ``irm_model`` once.

    By default we do all domains, all relations. We assume a homogeneous
    model for the moment.

    :param irm_model: model whose ``domains`` / ``relations`` are updated.
    :param rng: random number generator handed to every kernel.
    :param kernel_config: sequence of ``(kernel_name, params)`` pairs.
    :param iteration: iteration number, forwarded to ``anneal`` and to the
        nested ``do_inference`` calls.
    :param reverse: if true, the kernels are applied in reverse order.
    :param states_at_temps: per-temperature latent states; required by the
        ``parallel_tempering`` kernel, one entry per temperature.
    :param threadpool: forwarded to ``nonconj_gibbs``, ``relation_hp_grid``
        and to the nested sweeps of ``anneal`` and ``tempered_transitions``.
    :returns: ``{'kernel_times': [(kernel_name, seconds), ...]}``. If a
        ``parallel_tempering`` kernel ran, the states returned by
        ``kernels.parallel_tempering`` are returned instead. The way values
        are returned from PT here is an abomination, and should be resolved
        at some point.
    :raises Exception: on an unknown kernel name, or when ``states_at_temps``
        is missing or does not match the PT temperature list.
    """
    # Timings live in their own list so that the PT branch can replace
    # ``res`` with the chain states without breaking the per-kernel
    # bookkeeping below (indexing the states with 'kernel_times' used to
    # raise a TypeError).
    kernel_times = []
    res = {'kernel_times': kernel_times}

    for kernel_name, params in kernel_config[::-1 if reverse else 1]:
        t1 = time.time()
        if kernel_name == 'conj_gibbs':
            for _name, domain_inf in irm_model.domains.iteritems():
                gibbs.gibbs_sample_type(domain_inf, rng,
                                        params.get("impotent", False))
        elif kernel_name == 'fixed_gibbs':
            for _name, domain_inf in irm_model.domains.iteritems():
                gibbs.gibbs_sample_fixed_k(domain_inf, rng,
                                           params.get("impotent", False))
        elif kernel_name == 'nonconj_gibbs':
            for _name, domain_inf in irm_model.domains.iteritems():
                gibbs.gibbs_sample_type_nonconj(domain_inf,
                                                params.get("M", 10), rng,
                                                params.get("impotent", False),
                                                threadpool=threadpool)
        elif kernel_name == "slice_params":
            for _name, relation in irm_model.relations.iteritems():
                relation.apply_comp_kernel("slice_sample", rng, params)
        elif kernel_name == "continuous_mh_params":
            for _name, relation in irm_model.relations.iteritems():
                relation.apply_comp_kernel("continuous_mh", rng, params)
        elif kernel_name == "tempered_transitions":
            temps = params['temps']
            subkernels = params['subkernels']
            kernels.tempered_transitions(
                irm_model, rng, temps, irmio.get_latent,
                lambda x, y: irmio.set_model_latent(x, y, rng),
                model.IRM.set_temp,
                # Nested sweeps get the same threadpool, as in ``anneal``.
                lambda x, y, r: do_inference(x, y, subkernels, iteration, r,
                                             threadpool=threadpool))
        elif kernel_name == "parallel_tempering":
            temps = params['temps']
            subkernels = params['subkernels']
            if states_at_temps is None or len(states_at_temps) != len(temps):
                raise Exception("Insufficient latent states")
            states_at_temps = kernels.parallel_tempering(
                irm_model, states_at_temps, rng, temps, irmio.get_latent,
                lambda x, y: irmio.set_model_latent(x, y, rng),
                model.IRM.set_temp,
                # NOTE(review): threadpool is not forwarded here; confirm
                # whether kernels.parallel_tempering runs chains
                # concurrently before passing it through.
                lambda x, y: do_inference(x, y, subkernels, iteration))
            # Leave the model holding the first returned state.
            irmio.set_model_latent(irm_model, states_at_temps[0], rng)
            res = states_at_temps
        elif kernel_name == "anneal":
            temp_sched = params['anneal_sched']
            subkernels = params['subkernels']
            sub_res = kernels.anneal(
                irm_model, rng, temp_sched, iteration, model.IRM.set_temp,
                lambda x, y: do_inference(x, y, subkernels, iteration,
                                          threadpool=threadpool))
            # Splice the nested kernels' timings into ours (admittedly gross).
            kernel_times.extend(sub_res['kernel_times'])
        elif kernel_name == "domain_hp_grid":
            grid = params['grid']
            kernels.domain_hp_grid(irm_model, rng, grid)
        elif kernel_name == "relation_hp_grid":
            grids = params['grids']
            kernels.relation_hp_grid(irm_model, rng, grids, threadpool)
        else:
            raise Exception("Malformed kernel config, unknown kernel %s"
                            % kernel_name)
        t2 = time.time()
        kernel_times.append((kernel_name, t2 - t1))
        # Single-argument print: identical output on Python 2 and 3.
        print("kernels: %s %3.2f sec" % (kernel_name, t2 - t1))
    return res