def __init__(self, dataset, actions, mode='fixed_order', bounds=None, model=None, fastinf_model_name='perfect'): assert (mode in self.accepted_modes) self.mode = mode self.dataset = dataset self.actions = actions self.bounds = bounds self.fastinf_model_name = fastinf_model_name # Is GIST in the actions? Need to behave differently if so. self.gist_mode = ('gist' in [action.name for action in self.actions]) self.num_obs_vars = len(self.actions) if self.gist_mode: assert (self.actions[0].name == 'gist') self.num_obs_vars = len(self.actions) - 1 + len( self.dataset.classes) if mode == 'random': if model: assert (isinstance(model, RandomModel)) self.model = model else: self.model = RandomModel(len(self.dataset.classes)) elif mode == 'no_smooth' or mode == 'backoff': if model: assert (isinstance(model, NGramModel)) self.model = model else: self.model = NGramModel(dataset, mode) elif mode == 'fixed_order': if model: assert (isinstance(model, FixedOrderModel)) self.model = model else: self.model = FixedOrderModel(dataset) elif mode == 'fastinf': if model: assert (isinstance(model, FastinfModel)) self.model = model else: self.model = FastinfModel(dataset, self.fastinf_model_name, self.num_obs_vars) else: raise RuntimeError("Unknown mode") self.reset() self.orig_p_c = self.get_p_c()
def test(): dataset = Dataset('full_pascal_trainval') fm = FastinfModel(dataset, 'perfect', 20) # NOTE: just took values from a run of the thing prior_correct = [ float(x) for x in "0.050543 0.053053 0.073697 0.038331 0.050954 0.041879 0.16149\ 0.068721 0.10296 0.026837 0.043779 0.087683 0.063447 0.052205\ 0.41049 0.051664 0.014211 0.068361 0.056969 0.05046".split() ] np.testing.assert_almost_equal(fm.p_c, prior_correct, 4) observations = np.zeros(20) taken = np.zeros(20) fm.update_with_observations(taken, observations) np.testing.assert_almost_equal(fm.p_c, prior_correct, 4) observations[5] = 1 taken[5] = 1 fm.update_with_observations(taken, observations) print fm.p_c correct = [ float(x) for x in "0.027355 0.11855 0.027593 0.026851 0.012569 0.98999 0.52232\ 0.017783 0.010806 0.015199 0.0044641 0.02389 0.033602 0.089089\ 0.50297 0.0083272 0.0088274 0.0098522 0.034259 0.0086298".split() ] np.testing.assert_almost_equal(fm.p_c, correct, 4) observations[15] = 0 taken[15] = 1 fm.update_with_observations(taken, observations) correct = [ float(x) for x in "2.73590000e-02 1.19030000e-01 2.75500000e-02 2.68760000e-02 \ 1.23920000e-02 9.90200000e-01 5.25320000e-01 1.76120000e-02 \ 1.05030000e-02 1.52130000e-02 4.26410000e-03 2.38250000e-02 \ 3.36870000e-02 8.96450000e-02 5.04300000e-01 8.71880000e-05 \ 8.82630000e-03 9.55290000e-03 3.43240000e-02 8.44510000e-03".split() ] np.testing.assert_almost_equal(fm.p_c, correct) # reinit_marginals fm.reset() np.testing.assert_equal(fm.p_c, prior_correct) print(fm.cache)