def ddnll(self, data=None, params=None):
    """Build one ``nparams x nparams`` matrix per datum from factor expectations.

    For every datum the factor tables are rebuilt from its features,
    sum-product message passing is run, and the outer products of the
    per-factor ``Efeats`` vectors are summed over all ordered factor pairs.
    NOTE(review): the name suggests a second-derivative term of the negative
    log-likelihood -- confirm against the callers.

    Parameters:
        data: sequence of ``self.Data`` items; none of them may contain
            ``None`` in its ``values``.
        params: optional; when not ``None`` it is stored in ``self.params``
            after validation and before any computation.

    Returns:
        Array of shape ``(len(data), self.nparams, self.nparams)``.

    Raises:
        Exception: if an item is not a ``self.Data`` instance, or if an
            item's ``values`` contains ``None``.
    """
    # Two separate passes on purpose: every type check happens before any
    # ``values`` check, so the first error reported does not depend on order.
    if not all(isinstance(sample, self.Data) for sample in data):
        raise Exception('Data type not understood')
    for sample in data:
        if None in sample.values:
            raise Exception('Data.values should be fully determined.')

    if params is not None:
        self.params = params

    nsamples = len(data)
    result = np.empty((nsamples, self.nparams, self.nparams))
    for idx, sample in enumerate(data):
        # Refresh every factor table for this datum before inference.
        for factor in self.factors:
            self.vp_node[factor].make_table(sample.feats)
        algo.message_passing(self, 'sum-product')

        expected = []
        for factor in self.factors:
            expected.append(list(self.vp_node[factor].Efeats(sample)))

        # All ordered pairs of factors, outer loop over the right operand.
        pair_products = []
        for right in expected:
            for left in expected:
                pair_products.append(np.outer(left, right))
        result[idx] = np.sum(pair_products, axis=0)
    return result
def other_pr(self, v, data):
    """Return ``self.pr(v.vertex)`` after sum-product inference on each datum.

    Parameters:
        v: object whose ``vertex`` attribute is passed to ``self.pr``.
        data: a single ``self.Data`` item or a list of them.

    Returns:
        A list with one ``self.pr`` result per datum when ``data`` is a
        list; otherwise the single result for the one datum.

    Raises:
        Exception: if any item is not a ``self.Data`` instance.

    Side effects:
        ``self.values`` is overwritten with each datum's ``values`` and the
        factor tables are rebuilt from each datum's ``feats``.
    """
    single = not isinstance(data, list)
    batch = [data] if single else data
    for item in batch:
        if not isinstance(item, self.Data):
            raise Exception('Data type not understood')

    results = []
    for item in batch:
        self.values = item.values
        for factor in self.factors:
            self.vp_node[factor].make_table(item.feats)
        algo.message_passing(self, 'sum-product')
        results.append(self.pr(v.vertex))

    # Mirror the shape of the input: bare item in, bare result out.
    if single:
        return results[0]
    return results
def viterbi(self, data=None, params=None):
    """Run max-product inference per datum and return data with argmax values.

    Parameters:
        data: a single ``self.Data`` item or a list of them.
        params: optional; when not ``None`` it is stored in ``self.params``
            after validation and before inference.

    Returns:
        For list input, a list in which each datum is replaced (via
        ``_replace``) by a copy whose ``values`` is ``self.argmax()``;
        for a single datum, that one replaced copy.

    Raises:
        Exception: if any item is not a ``self.Data`` instance.

    Side effects:
        ``self.values`` is overwritten with each datum's ``values`` and the
        factor tables are rebuilt from each datum's ``feats``.
    """
    single = not isinstance(data, list)
    batch = [data] if single else data
    for item in batch:
        if not isinstance(item, self.Data):
            raise Exception('Data type not understood')

    if params is not None:
        self.params = params

    # TODO copy input and add new values
    decoded = []
    for item in batch:
        self.values = item.values
        for factor in self.factors:
            self.vp_node[factor].make_table(item.feats)
        algo.message_passing(self, 'max-product')
        decoded.append(item._replace(values=self.argmax()))

    # Mirror the shape of the input: bare item in, bare result out.
    if single:
        return decoded[0]
    return decoded
def nll(self, data=None, params=None):
    """Compute ``self.logZ`` minus the summed factor log-values, per datum.

    For every datum the factor tables are rebuilt from its features and
    sum-product message passing is run; the entry for that datum is then
    ``self.logZ - sum(log_value(d) over all factors)``.
    NOTE(review): the name suggests this is the negative log-likelihood of
    each datum -- confirm against the callers.

    Parameters:
        data: sequence of ``self.Data`` items; none of them may contain
            ``None`` in its ``values``.
        params: optional; when not ``None`` it is stored in ``self.params``
            after validation and before any computation.

    Returns:
        1-D array with ``len(data)`` entries, one per datum.

    Raises:
        Exception: if an item is not a ``self.Data`` instance, or if an
            item's ``values`` contains ``None``.
    """
    if any(not isinstance(d, self.Data) for d in data):
        raise Exception('Data type not understood')
    if any(None in d.values for d in data):
        raise Exception('Data.values should be fully determined.')

    if params is not None:
        self.params = params

    losses = np.empty(len(data))
    for i, d in enumerate(data):
        # Refresh every factor table for this datum before inference.
        for f in self.factors:
            self.vp_node[f].make_table(d.feats)
        algo.message_passing(self, 'sum-product')
        # Use the built-in sum: np.sum() on a generator is deprecated (it
        # only works by falling back to the built-in, with a warning, and is
        # slated to change meaning). The built-in gives the same result.
        losses[i] = self.logZ - sum(
            self.vp_node[f].log_value(d) for f in self.factors
        )
    return losses
F2.table = [[ 10, 1 ], [ 1, 10 ]] # F3 prefers if V2 and V3 are different F3.table = [[ 1, 10 ], [ 10, 1 ]] fg.make() return fg def simple_tabgraph(): return make_tabgraph(vfun=simple_variables) def domain_tabgraph(): return make_tabgraph(vfun=domain_variables) if __name__ == '__main__': fmt = '%(levelname)s @%(lineno)d:%(filename)s - %(funcName)s(): %(message)s' fmt = '%(asctime)s %(levelname)s @%(lineno)d:%(filename)s - %(funcName)s(): %(message)s' logging.basicConfig(filename='log.tabgraph.log', filemode='w', format=fmt, level=logging.DEBUG) fg = simple_tabgraph() message_passing(fg, 'max-product', 'sum-product') print 'max: {}'.format(fg.max()) print 'argmax: {}'.format(fg.argmax())
def test_domain_viterbi(self):
    """Check max and argmax of the domain tabgraph after message passing.

    Runs 'sum-product' and 'max-product' on ``tabgraph.domain_tabgraph()``
    and expects a maximum strictly within 1e-5 of 1000 and the argmax
    ``['This', 'Code', 'Rules']``.
    """
    tolerance = 1e-5
    graph = tabgraph.domain_tabgraph()
    algo.message_passing(graph, 'sum-product', 'max-product')

    best = graph.max()
    lower, upper = 1000. - tolerance, 1000. + tolerance
    self.assertTrue(lower < best < upper)
    self.assertEqual(graph.argmax(), ['This', 'Code', 'Rules'])
def test_simple_viterbi(self):
    """Check max and argmax of the simple tabgraph after message passing.

    Runs 'sum-product' and 'max-product' on ``tabgraph.simple_tabgraph()``
    and expects a maximum strictly within 1e-5 of 1000 and the argmax
    ``[0, 0, 1]``.
    """
    tolerance = 1e-5
    graph = tabgraph.simple_tabgraph()
    algo.message_passing(graph, 'sum-product', 'max-product')

    best = graph.max()
    lower, upper = 1000. - tolerance, 1000. + tolerance
    self.assertTrue(lower < best < upper)
    self.assertEqual(graph.argmax(), [0, 0, 1])