Example #1
File: kb.py  Project: Frostie2/zamia-ai
# Imports reconstructed so the listing is self-contained; the exact module
# paths (nltools, sparqlalchemy, ldfmirror) are assumptions based on how the
# names are used below.
import codecs
import json
import logging
import unittest

import rdflib
import requests
from requests.auth import HTTPDigestAuth

from nltools import misc                                     # misc.load_config (assumption)
from sparqlalchemy.sparqlalchemy import SPARQLAlchemyStore   # (assumption)
from ldfmirror import LDFMirror                              # LDF mirroring helper (assumption)


class AIKB(object):

    def __init__(self, kbname='kb'):

        #
        # prepare our lightweight sparql wrapper
        #

        self.query_prefixes = ''

        #
        # set up graph store
        #

        config = misc.load_config('.airc')

        # DB, SPARQLAlchemyStore

        db_url = config.get('db', 'url')

        self.sas = SPARQLAlchemyStore(db_url, kbname, echo=False)

        self.endpoints = {} # host name -> LDF endpoint

    def register_prefix(self, prefix, uri):
        self.query_prefixes += "PREFIX %s: <%s>\n" % (prefix, uri)
        self.sas.register_prefix(prefix, uri)
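
        # Illustrative (not from the original source): after
        #   kb.register_prefix('rdfs', 'http://www.w3.org/2000/01/rdf-schema#')
        # the "rdfs:" shorthand works in every query passed to query() below,
        # because the PREFIX line is prepended automatically.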

    def register_endpoint(self, endpoint, uri):
        self.endpoints[endpoint] = uri

    def register_alias(self, alias, uri):
        self.sas.register_alias(alias, uri)
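
    # Illustrative usage (assumptions, not from the original source):
    #   kb.register_endpoint('fragments.dbpedia.org',
    #                        'http://fragments.dbpedia.org/2016-04/en')
    #   kb.register_alias('helmut_kohl',
    #                     'http://dbpedia.org/resource/Helmut_Kohl')
    # Endpoints are consumed by ldf_mirror() below; aliases are expanded by
    # resolve_aliases_prefixes().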

    def register_graph(self, c):

        # FIXME: implement dump functions in sparqlalchemy
        pass

        # # Bind a few prefix/namespace pairs for more readable output

        # g = self.graph.get_context(c)
        # for p in COMMON_PREFIXES:
        #     g.bind(p, rdflib.Namespace(COMMON_PREFIXES[p]))

    def close(self):
        # self.graph.close()
        pass

    def clear_graph(self, context):
        self.sas.clear_graph(context)
        # query = """
        #         CLEAR GRAPH <%s>
        #         """ % (context)
        # self.sparql(query)

    def clear_all_graphs(self):
        self.sas.clear_all_graphs()

    def dump(self, fn, format='n3'):

        raise NotImplementedError('FIXME: implement dump functions in sparqlalchemy')

        # previous rdflib-based implementation:
        # self.graph.serialize(destination=fn, format=format)

    def dump_graph(self, context, fn, format='n3'):

        raise NotImplementedError('FIXME: implement dump functions in sparqlalchemy')

        # previous rdflib-based implementation (fixed to honor the format
        # argument instead of hard-coding 'n3'):
        # g = self.graph.get_context(context)
        # g.serialize(destination=fn, format=format)

    def parse(self, context, format, data):
        self.sas.parse(format=format, data=data, context=context)

    def parse_file(self, context, format, fn):
        self.sas.parse(fn, format=format, context=context)

    def addN(self, quads):
        self.sas.addN(quads)

    def remove(self, quad):
        self.sas.remove(quad)

    def filter_quads(self, s=None, p=None, o=None, context=None):
        return self.sas.filter_quads(s=s, p=p, o=o, context=context)
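
        # Note: s, p, o and context act as wildcards when left as None, e.g.
        #   filter_quads(s=u'http://dbpedia.org/resource/Helmut_Kohl',
        #                context=u'http://example.com')
        # returns every quad with that subject in the given graph (see
        # test_filter_quads below).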

    def resolve_aliases_prefixes(self, resource):
        return self.sas.resolve_shortcuts(resource)

    def addN_resolve(self, quads):

        quads_resolved = []

        for s, p, o, c in quads:
            quads_resolved.append((self.resolve_aliases_prefixes(s),
                                   self.resolve_aliases_prefixes(p),
                                   self.resolve_aliases_prefixes(o), c))

        self.addN(quads_resolved)
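
        # Illustrative (not from the original source): with 'dbr', 'rdf' and
        # 'schema' registered as prefixes, a call might look like
        #   kb.addN_resolve([(u'dbr:Helmut_Kohl', u'rdf:type', u'schema:Person',
        #                     rdflib.Graph(identifier=u'http://example.com'))])
        # each shortcut is expanded via resolve_aliases_prefixes() before the
        # quads are stored.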

    #
    # local sparql queries
    #

    def sparql(self, query):

        raise NotImplementedError('FIXME: sparql update queries not implemented yet.')

        # query  = self.query_prefixes + query
        # return self.graph.update(query)

    def query(self, query):

        query  = self.query_prefixes + query
        # logging.debug (query)

        return self.sas.query(query)
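
        # The result object is iterable; each row maps the query's variables
        # to their bindings (see the test cases below):
        #   res = kb.query(u'SELECT ?s WHERE { ?s rdf:type schema:Person }')
        #   for row in res:
        #       for v in res.vars:
        #           logging.debug('%s = %s' % (v, row[v]))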

    def query_algebra(self, algebra):
        return self.sas.query_algebra(algebra)

    #
    # remote sparql utilities
    #

    def remote_sparql(self, endpoint, query, user=None, passwd=None, response_format='application/sparql-results+json'):

        if user:
            auth   = HTTPDigestAuth(user, passwd)
        else:
            auth   = None

        query  = self.query_prefixes + query
        # logging.debug(query)

        response = requests.post(
          endpoint,
          params  = {'query': query},
          headers = {"accept": response_format},
          auth    = auth
        )
        return response

    def remote_query(self, endpoint, query, user=None, passwd=None):

        response = self.remote_sparql(endpoint, query, user=user, passwd=passwd)

        # requests already decodes the body, so response.text is a unicode
        # string and must not be decoded again
        return json.loads(response.text)
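
    # Illustrative usage (not from the original source; any public SPARQL
    # endpoint would do):
    #   rows = kb.remote_query('https://dbpedia.org/sparql',
    #                          'SELECT * WHERE { ?s ?p ?o } LIMIT 5')
    # rows holds the parsed application/sparql-results+json response.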

    #
    # LDF support
    #

    def ldf_mirror(self, res_paths, context):

        ldfmirror = LDFMirror(self.sas, self.endpoints)

        ldfmirror.mirror(res_paths, context)
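

# --- Illustrative usage sketch (not part of the original project code;
# assumes a valid .airc config whose [db] section provides 'url') ---

def _example_kb_usage():

    kb = AIKB(kbname='kb')
    kb.register_prefix('rdf',  'http://www.w3.org/1999/02/22-rdf-syntax-ns#')
    kb.register_prefix('rdfs', 'http://www.w3.org/2000/01/rdf-schema#')

    # import a single triple into a named graph
    kb.parse(context=u'http://example.com', format='n3',
             data=u'<http://example.com/a> <http://www.w3.org/2000/01/rdf-schema#label> "a" .')

    # registered prefixes are prepended to the query automatically
    res = kb.query(u'SELECT ?s ?l WHERE { ?s rdfs:label ?l }')
    for row in res:
        logging.debug('row: %s' % (row,))

    kb.close()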

# NUM_SAMPLE_ROWS is defined at module level in the original test module; it
# must equal the number of triples in the tests/triples.n3 fixture.

class TestSPARQLAlchemy(unittest.TestCase):
    def setUp(self):

        config = misc.load_config('.airc')

        #
        # db, store
        #

        db_url = config.get('db', 'url')
        # db_url = 'sqlite:///tmp/foo.db'

        self.sas = SPARQLAlchemyStore(db_url, 'unittests', echo=True)
        self.context = u'http://example.com'

        #
        # import triples to test on
        #

        self.sas.clear_all_graphs()

        samplefn = 'tests/triples.n3'

        with codecs.open(samplefn, 'r', 'utf8') as samplef:

            data = samplef.read()

            self.sas.parse(data=data, context=self.context, format='n3')

    # @unittest.skip("temporarily disabled")
    def test_import(self):
        self.assertEqual(len(self.sas), NUM_SAMPLE_ROWS)

    # @unittest.skip("temporarily disabled")
    def test_clear_graph(self):
        self.assertEqual(len(self.sas), NUM_SAMPLE_ROWS)

        # add a triple belonging to a different context
        foo_context = u'http://foo.com'
        self.sas.addN([(u'foo', u'bar', u'baz',
                        rdflib.Graph(identifier=foo_context))])
        self.assertEqual(len(self.sas), NUM_SAMPLE_ROWS + 1)

        # clear context that does not exist
        self.sas.clear_graph(u'http://bar.com')
        self.assertEqual(len(self.sas), NUM_SAMPLE_ROWS + 1)

        # clear context that does exist, make sure other triples survive
        self.sas.clear_graph(self.context)
        self.assertEqual(len(self.sas), 1)

        # add a triple belonging to yet another context
        foo_context = u'http://baz.com'
        self.sas.addN([(u'foo', u'bar', u'baz',
                        rdflib.Graph(identifier=foo_context))])
        self.assertEqual(len(self.sas), 2)

        # test clear_all_graphs

        self.sas.clear_all_graphs()
        self.assertEqual(len(self.sas), 0)

    # @unittest.skip("temporarily disabled")
    def test_query_optional(self):

        sparql = """
                 PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                 PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                 PREFIX schema: <http://schema.org/>
                 PREFIX dbr: <http://dbpedia.org/resource/>
                 PREFIX dbo: <http://dbpedia.org/ontology/>
                 SELECT ?leader ?label ?leaderobj 
                 WHERE {
                     ?leader rdfs:label ?label. 
                     ?leader rdf:type schema:Person.
                     OPTIONAL {?leaderobj dbo:leader ?leader}
                 }
                 """

        res = self.sas.query(sparql)

        self.assertEqual(len(res), 24)

        for row in res:
            s = ''
            for v in res.vars:
                s += ' %s=%s' % (v, row[v])
            logging.debug('sparql result row: %s' % s)

    # @unittest.skip("temporarily disabled")
    def test_query_limit(self):

        sparql = """
                 PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                 PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                 PREFIX schema: <http://schema.org/>
                 PREFIX dbr: <http://dbpedia.org/resource/>
                 PREFIX dbo: <http://dbpedia.org/ontology/>
                 SELECT ?leader ?label ?leaderobj 
                 WHERE {
                     ?leader rdfs:label ?label. 
                     ?leader rdf:type schema:Person.
                     OPTIONAL {?leaderobj dbo:leader ?leader}
                 }
                 LIMIT 1
                 """

        res = self.sas.query(sparql)

        self.assertEqual(len(res), 1)

    # @unittest.skip("temporarily disabled")
    def test_query_filter(self):

        sparql = """
                 PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                 PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                 PREFIX schema: <http://schema.org/>
                 PREFIX dbr: <http://dbpedia.org/resource/>
                 PREFIX dbo: <http://dbpedia.org/ontology/>
                 SELECT ?leader ?label ?leaderobj 
                 WHERE {
                     ?leader rdfs:label ?label. 
                     ?leader rdf:type schema:Person.
                     OPTIONAL {?leaderobj dbo:leader ?leader}
                     FILTER (lang(?label) = 'de')
                 }
                 """

        res = self.sas.query(sparql)

        self.assertEqual(len(res), 2)

        for row in res:
            s = ''
            for v in res.vars:
                s += ' %s=%s' % (v, row[v])
            logging.debug('sparql result row: %s' % s)

        sparql = """
                 PREFIX rdfs:   <http://www.w3.org/2000/01/rdf-schema#>
                 PREFIX rdf:    <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                 PREFIX schema: <http://schema.org/>
                 PREFIX dbr:    <http://dbpedia.org/resource/>
                 PREFIX dbo:    <http://dbpedia.org/ontology/>
                 PREFIX owl:    <http://www.w3.org/2002/07/owl#> 
                 PREFIX wdt:    <http://www.wikidata.org/prop/direct/> 
                 SELECT ?label ?birthPlace ?wdgenderlabel
                 WHERE {
                     ?chancellor rdfs:label ?label.
                     ?chancellor dbo:birthPlace ?birthPlace.
                     ?chancellor rdf:type schema:Person.
                     ?birthPlace rdf:type dbo:Settlement.
                     ?chancellor owl:sameAs ?wdchancellor.
                     ?wdchancellor wdt:P21 ?wdgender.
                     ?wdgender rdfs:label ?wdgenderlabel.
                     FILTER (lang(?label) = 'de')
                     FILTER (lang(?wdgenderlabel) = 'de')
                 }"""

        res = self.sas.query(sparql)

        self.assertEqual(len(res), 2)

        for row in res:
            s = ''
            for v in res.vars:
                s += ' %s=%s' % (v, row[v])
            logging.debug('sparql result row: %s' % s)

        sparql = """
                 PREFIX rdfs:   <http://www.w3.org/2000/01/rdf-schema#>
                 PREFIX rdf:    <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                 PREFIX schema: <http://schema.org/>
                 PREFIX dbr:    <http://dbpedia.org/resource/>
                 PREFIX dbo:    <http://dbpedia.org/ontology/>
                 PREFIX dbp:    <http://dbpedia.org/property/>
                 PREFIX owl:    <http://www.w3.org/2002/07/owl#> 
                 PREFIX wdt:    <http://www.wikidata.org/prop/direct/> 
                 SELECT ?label ?leaderof
                 WHERE {
                     ?chancellor rdfs:label ?label.
                     ?chancellor rdf:type schema:Person.
                     ?chancellor dbp:office dbr:Chancellor_of_Germany.
                     OPTIONAL { ?leaderof dbo:leader ?chancellor }.
                     FILTER (lang(?label) = 'de')
                 }"""

        res = self.sas.query(sparql)

        self.assertEqual(len(res), 1)

        for row in res:
            s = ''
            for v in res.vars:
                s += ' %s=%s' % (v, row[v])
            logging.debug('sparql result row: %s' % s)

    # @unittest.skip("temporarily disabled")
    def test_distinct(self):

        sparql = """
                 PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                 PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                 PREFIX schema: <http://schema.org/>
                 PREFIX dbr: <http://dbpedia.org/resource/>
                 PREFIX dbo: <http://dbpedia.org/ontology/>
                 SELECT DISTINCT ?leader  
                 WHERE {
                     ?leader rdfs:label ?label. 
                     ?leader rdf:type schema:Person.
                 }
                 """

        res = self.sas.query(sparql)

        self.assertEqual(len(res), 2)

        for row in res:
            s = ''
            for v in res.vars:
                s += ' %s=%s' % (v, row[v])
            logging.debug('sparql result row: %s' % s)

    # @unittest.skip("temporarily disabled")
    def test_dates(self):

        sparql = """
                 PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                 PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                 PREFIX schema: <http://schema.org/>
                 PREFIX dbr: <http://dbpedia.org/resource/>
                 PREFIX dbo: <http://dbpedia.org/ontology/>
                 PREFIX hal: <http://hal.zamia.org/kb/>
                 PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
                 SELECT ?temp_min ?temp_max ?precipitation ?clouds ?icon
                 WHERE {
                     ?wev hal:dt_end ?dt_end. 
                     ?wev hal:dt_start ?dt_start.
                     ?wev hal:location dbr:Stuttgart.
                     ?wev hal:temp_min ?temp_min   .
                     ?wev hal:temp_max ?temp_max   .
                     ?wev hal:precipitation ?precipitation .
                     ?wev hal:clouds ?clouds .
                     ?wev hal:icon ?icon .
                     FILTER (?dt_start >= "2016-12-04T10:20:13+05:30"^^xsd:dateTime &&
                             ?dt_end   <= "2016-12-23T10:20:13+05:30"^^xsd:dateTime)
                 }
                 """

        res = self.sas.query(sparql)

        self.assertEqual(len(res), 2)

        for row in res:
            s = ''
            for v in res.vars:
                s += ' %s=%s' % (v, row[v])
            logging.debug('sparql result row: %s' % s)

    # @unittest.skip("temporarily disabled")
    def test_filter_quads(self):

        quads = self.sas.filter_quads(None, None, None, self.context)
        self.assertEqual(len(quads), NUM_SAMPLE_ROWS)

        quads = self.sas.filter_quads(
            u'http://dbpedia.org/resource/Helmut_Kohl', None, None,
            self.context)
        self.assertEqual(len(quads), 73)

        quads = self.sas.filter_quads(
            u'http://dbpedia.org/resource/Helmut_Kohl',
            u'http://dbpedia.org/ontology/birthPlace', None, self.context)
        self.assertEqual(len(quads), 2)
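

# Illustrative test runner (not part of the original excerpt); NUM_SAMPLE_ROWS
# and the tests/triples.n3 fixture come from the original test module.
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()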