Exemplo n.º 1
0
    def ddnll(self, data = None, params = None):
        """Return the per-example matrix computed under the name ``ddnll``.

        For every example ``d``:
          1. the factor tables are rebuilt from ``d.feats``;
          2. sum-product message passing is run;
          3. the vectors ``Efeats(d)`` of all factors are summed into ``s``;
          4. the example's matrix is ``outer(s, s)``.

        :param data: a list (or other iterable) of ``self.Data`` instances,
            or a single ``self.Data`` instance.  Every ``d.values`` must be
            fully determined (contain no ``None``).
        :param params: if not None, assigned to ``self.params`` first.
        :returns: numpy array of shape ``(len(data), nparams, nparams)``, or
            a single ``(nparams, nparams)`` array when one ``self.Data`` was
            passed directly.
        :raises Exception: 'Data type not understood' if an element is not a
            ``self.Data`` (this includes ``data=None``);
            'Data.values should be fully determined.' if any value is None.
        """
        # The isinstance test must come first: Data has `_replace`, so it is
        # presumably a namedtuple and would otherwise be iterated field by
        # field.  None is wrapped so it fails the type check with a clear
        # message instead of "NoneType is not iterable".
        single = isinstance(data, self.Data)
        _data = [data] if single or data is None else list(data)

        if any(not isinstance(d, self.Data) for d in _data):
            raise Exception('Data type not understood')
        if any(None in d.values for d in _data):
            raise Exception('Data.values should be fully determined.')

        if params is not None:
            self.params = params

        hess = np.empty((len(_data), self.nparams, self.nparams))
        for i, d in enumerate(_data):
            for f in self.factors:
                self.vp_node[f].make_table(d.feats)
            algo.message_passing(self, 'sum-product')

            # sum_{a,b} outer(e_a, e_b) == outer(sum_a e_a, sum_b e_b): one
            # outer product instead of len(self.factors)**2 of them.
            efeats = [list(self.vp_node[f].Efeats(d)) for f in self.factors]
            total = np.sum(efeats, axis=0)
            hess[i] = np.outer(total, total)
            # NOTE(review): if Efeats are expected feature counts, the usual
            # NLL Hessian is the covariance E[f f^T] - E[f] E[f]^T rather
            # than outer(E[f], E[f]) alone -- confirm this is intended.

        return hess[0] if single else hess
Exemplo n.º 2
0
    def other_pr(self, v, data):
        """Return ``self.pr(v.vertex)`` after sum-product inference per example.

        For each example, ``self.values`` is set from the example, the factor
        tables are rebuilt from its ``feats``, and sum-product message
        passing is run before querying ``self.pr``.

        :param v: object whose ``vertex`` attribute is passed to ``self.pr``.
        :param data: a list of ``self.Data`` instances, or one instance.
        :returns: a list of results when ``data`` is a list, otherwise the
            single result.
        :raises Exception: 'Data type not understood' if an element is not a
            ``self.Data``.
        """
        batched = isinstance(data, list)
        examples = data if batched else [data]
        for example in examples:
            if not isinstance(example, self.Data):
                raise Exception('Data type not understood')

        results = []
        for example in examples:
            self.values = example.values
            for factor in self.factors:
                self.vp_node[factor].make_table(example.feats)
            algo.message_passing(self, 'sum-product')
            results.append(self.pr(v.vertex))

        if batched:
            return results
        return results[0]
Exemplo n.º 3
0
    def viterbi(self, data = None, params = None):
        """Return the max-product (Viterbi) assignment for each example.

        For each example, ``self.values`` is set from the example, the factor
        tables are rebuilt from its ``feats``, and max-product message
        passing is run.  The example is then returned through
        ``_replace(values=self.argmax())``.  ``_replace`` is a
        namedtuple-style API and presumably yields a new object rather than
        mutating the input.

        :param data: a list of ``self.Data`` instances, or one instance.
        :param params: if not None, assigned to ``self.params`` first.
        :returns: a list of decoded examples when ``data`` is a list,
            otherwise the single decoded example.
        :raises Exception: 'Data type not understood' if an element is not a
            ``self.Data`` (this includes the default ``data=None``).
        """
        batched = isinstance(data, list)
        examples = data if batched else [data]
        for example in examples:
            if not isinstance(example, self.Data):
                raise Exception('Data type not understood')

        if params is not None:
            self.params = params

        decoded = []
        for example in examples:
            self.values = example.values
            for factor in self.factors:
                self.vp_node[factor].make_table(example.feats)
            algo.message_passing(self, 'max-product')
            decoded.append(example._replace(values=self.argmax()))

        return decoded if batched else decoded[0]
Exemplo n.º 4
0
    def nll(self, data = None, params = None):
        """Return the negative log-likelihood of each example.

        For every example ``d``:
          1. the factor tables are rebuilt from ``d.feats``;
          2. sum-product message passing is run;
          3. the result is
             ``self.logZ - sum(self.vp_node[f].log_value(d) for f in self.factors)``.

        :param data: a list (or other iterable) of ``self.Data`` instances,
            or a single ``self.Data`` instance.  Every ``d.values`` must be
            fully determined (contain no ``None``).
        :param params: if not None, assigned to ``self.params`` first.
        :returns: 1-D numpy array of length ``len(data)``, or a single value
            when one ``self.Data`` was passed directly.
        :raises Exception: 'Data type not understood' if an element is not a
            ``self.Data`` (this includes ``data=None``);
            'Data.values should be fully determined.' if any value is None.
        """
        # The isinstance test must come first: Data has `_replace`, so it is
        # presumably a namedtuple and would otherwise be iterated field by
        # field.  None is wrapped so it fails the type check with a clear
        # message instead of "NoneType is not iterable".
        single = isinstance(data, self.Data)
        _data = [data] if single or data is None else list(data)

        if any(not isinstance(d, self.Data) for d in _data):
            raise Exception('Data type not understood')
        if any(None in d.values for d in _data):
            raise Exception('Data.values should be fully determined.')

        if params is not None:
            self.params = params

        out = np.empty(len(_data))
        for i, d in enumerate(_data):
            for f in self.factors:
                self.vp_node[f].make_table(d.feats)
            algo.message_passing(self, 'sum-product')

            # Builtin sum over a list: np.sum(<generator>) is deprecated in
            # NumPy and only ever worked by falling back to this builtin.
            log_value = sum([self.vp_node[f].log_value(d) for f in self.factors])
            out[i] = self.logZ - log_value

        return out[0] if single else out
Exemplo n.º 5
0
    F2.table = [[ 10, 1 ],
                [ 1, 10 ]]

# F3 prefers if V2 and V3 are different
    F3.table = [[ 1, 10 ],
                [ 10, 1 ]]

    fg.make()
    return fg

def simple_tabgraph():
    """Return the example factor graph built from ``simple_variables``.

    Thin wrapper around ``make_tabgraph`` with ``vfun=simple_variables``.
    """
    variable_factory = simple_variables
    return make_tabgraph(vfun=variable_factory)

def domain_tabgraph():
    """Return the example factor graph built from ``domain_variables``.

    Thin wrapper around ``make_tabgraph`` with ``vfun=domain_variables``.
    """
    variable_factory = domain_variables
    return make_tabgraph(vfun=variable_factory)

if __name__ == '__main__':
    # Log every record at DEBUG level to log.tabgraph.log; filemode 'w'
    # truncates the previous run's log.
    fmt = '%(asctime)s %(levelname)s @%(lineno)d:%(filename)s - %(funcName)s(): %(message)s'
    logging.basicConfig(filename='log.tabgraph.log',
                        filemode='w',
                        format=fmt,
                        level=logging.DEBUG)

    fg = simple_tabgraph()
    message_passing(fg, 'max-product', 'sum-product')

    # print(...) with a single argument behaves the same on Python 2 and 3,
    # whereas the bare print statement is a SyntaxError on Python 3.
    print('max:    {}'.format(fg.max()))
    print('argmax: {}'.format(fg.argmax()))

Exemplo n.º 6
0
 def test_domain_viterbi(self):
     """Domain tabgraph: max is 1000 and argmax is ['This', 'Code', 'Rules']."""
     e = 1e-5
     fg = tabgraph.domain_tabgraph()
     algo.message_passing(fg, 'sum-product', 'max-product')
     # assertAlmostEqual reports the actual value on failure; assertTrue on
     # a chained comparison only reports "False is not true".
     self.assertAlmostEqual(fg.max(), 1000., delta=e)
     self.assertEqual(fg.argmax(), ['This', 'Code', 'Rules'])
Exemplo n.º 7
0
 def test_simple_viterbi(self):
     """Simple tabgraph: max is 1000 and argmax is [0, 0, 1]."""
     e = 1e-5
     fg = tabgraph.simple_tabgraph()
     algo.message_passing(fg, 'sum-product', 'max-product')
     # assertAlmostEqual reports the actual value on failure; assertTrue on
     # a chained comparison only reports "False is not true".
     self.assertAlmostEqual(fg.max(), 1000., delta=e)
     self.assertEqual(fg.argmax(), [0, 0, 1])