def relation_hp_grid(model, rng, grids, threadpool=None):
    """Resample each relation's hyperparameters over a discrete grid.

    Supports per-relation grids: for each relation the grid is looked up
    first by the relation's name and then by its model type string
    (``relation.modeltypestr``). If the selected grid is ``None``,
    hyperparameter inference is skipped for that relation.

    Args:
        model: object exposing a ``relations`` mapping of name -> relation.
        rng: not used by this function; kept for interface compatibility.
        grids: mapping from relation name or model type string to a grid
            of candidate hyperparameter values, or ``None`` to skip
            inference for matching relations.
        threadpool: forwarded to ``relation.score_at_hps`` for relations
            that are not ``pyirmutil.Relation`` instances.

    Raises:
        RuntimeError: if neither the relation name nor its model type
            string is a key of ``grids``.
    """
    for relation_name, relation in model.relations.iteritems():
        model_name = relation.modeltypestr
        # A relation-specific grid takes precedence over the model-type grid.
        if relation_name in grids:
            grid = grids[relation_name]
        elif model_name in grids:
            grid = grids[model_name]
        else:
            raise RuntimeError("model %s is not in the provided grids" % model_name)

        # Identity test, not ``== None``: if the grid is a numpy array,
        # ``grid == None`` is elementwise and ambiguous in a boolean context.
        if grid is None:
            continue

        if isinstance(relation, pyirmutil.Relation):
            # HACK: we should not be dispatching on type this way; fix in a
            # later version once the old code is obsoleted.
            # The callbacks close over the loop variable ``relation``; this
            # presumes grid_gibbs invokes them before the next iteration.
            def set_func(val):
                relation.set_hps(val)

            def get_score():
                return relation.total_score()

            gridgibbshps.grid_gibbs(set_func, get_score, grid)
        else:
            # Score every grid point at once, then sample one index from
            # those scores and apply the corresponding hyperparameters.
            scores = relation.score_at_hps(grid, threadpool)
            i = util.sample_from_scores(scores)
            relation.set_hps(grid[i])
def domain_hp_grid(model, rng, grid):
    """Grid-Gibbs resample the ``alpha`` hyperparameter of every domain.

    For each domain in ``model.domains``, candidate values from ``grid`` are
    applied as ``{'alpha': value}`` via ``domain.set_hps`` and scored with
    ``domain.get_prior_score()``; the choice among them is delegated to
    ``gridgibbshps.grid_gibbs``.

    Args:
        model: object exposing a ``domains`` mapping of name -> domain.
        rng: not used by this function; kept for interface compatibility.
        grid: candidate values for ``alpha``, passed to ``grid_gibbs``.
    """
    for _unused_name, dom in model.domains.iteritems():
        # The lambdas capture ``dom`` from this iteration and are handed to
        # grid_gibbs immediately.
        gridgibbshps.grid_gibbs(
            lambda val: dom.set_hps({'alpha' : val}),
            lambda: dom.get_prior_score(),
            grid,
        )