class TestAIProlog(unittest.TestCase):
    """Exercise the AI-Prolog built-ins (tokenize/3, edit_distance/3).

    Each test parses a goal, runs it through the runtime and checks the
    binding of variable X in the single expected solution.
    """

    def setUp(self):
        # loaded for its side effects; the returned config object is unused here
        config = misc.load_config('.airc')

        # logic database backing parser and runtime
        self.db = LogicDB(model.url)

        # AI-Prolog parser/runtime pair operating on that database
        self.prolog_rt = AIPrologRuntime(self.db)
        self.parser = AIPrologParser(self.db)
        self.prolog_rt.set_trace(True)

        # start every test from an empty unit-test module
        self.db.clear_module(UNITTEST_MODULE)

    # @unittest.skip("temporarily disabled")
    def test_tokenize(self):
        # tokenize/3 should split the German utterance into exactly two tokens
        goal = self.parser.parse_line_clause_body(
            "tokenize (de, 'hallo, welt!', X)")
        logging.debug('clause: %s' % goal)

        results = self.prolog_rt.search(goal)
        logging.debug('solutions: %s' % repr(results))

        self.assertEqual(len(results), 1)
        self.assertEqual(len(results[0]['X'].l), 2)

    # @unittest.skip("temporarily disabled")
    def test_edit_distance(self):
        # one token differs between the two lists -> distance 1.0
        goal = self.parser.parse_line_clause_body(
            "edit_distance (['hallo', 'welt'], ['hallo', 'springfield'], X)")
        logging.debug('clause: %s' % goal)

        results = self.prolog_rt.search(goal)
        logging.debug('solutions: %s' % repr(results))

        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]['X'].f, 1.0)
# NOTE(review): this re-defines class TestAIProlog -- another definition with
# the same name exists in this file; whichever appears later in the module
# shadows the earlier one. One of the two should probably be removed or renamed.
class TestAIProlog (unittest.TestCase):
    """Tests for the AI-Prolog tokenize/3 and edit_distance/3 built-ins."""

    def setUp(self):
        # loaded for side effects only; the return value is unused here
        config = misc.load_config('.airc')

        #
        # logic DB
        #

        self.db = LogicDB(model.url)

        #
        # aiprolog environment setup
        #

        self.prolog_rt = AIPrologRuntime(self.db)
        self.parser = AIPrologParser(self.db)
        self.prolog_rt.set_trace(True)

        # make sure the unit-test module starts out empty
        self.db.clear_module(UNITTEST_MODULE)

    # @unittest.skip("temporarily disabled")
    def test_tokenize(self):
        # expect exactly one solution binding X to a two-token list
        clause = self.parser.parse_line_clause_body("tokenize (de, 'hallo, welt!', X)")
        logging.debug('clause: %s' % clause)
        solutions = self.prolog_rt.search(clause)
        logging.debug('solutions: %s' % repr(solutions))
        self.assertEqual (len(solutions), 1)
        self.assertEqual (len(solutions[0]['X'].l), 2)

    # @unittest.skip("temporarily disabled")
    def test_edit_distance(self):
        # expect exactly one solution with edit distance 1.0 (one substitution)
        clause = self.parser.parse_line_clause_body("edit_distance (['hallo', 'welt'], ['hallo', 'springfield'], X)")
        logging.debug('clause: %s' % clause)
        solutions = self.prolog_rt.search(clause)
        logging.debug('solutions: %s' % repr(solutions))
        self.assertEqual (len(solutions), 1)
        self.assertEqual (solutions[0]['X'].f, 1.0)
class AIKernal(object):
    """Central coordinator of the AI system.

    Owns the SQL session, the logic database, the AI-Prolog parser/runtime,
    the (lazily created) TensorFlow NLP model and the loaded Python modules.
    Provides module loading/compilation, test running, cronjob execution,
    user-input processing and word2vec-based utterance alignment.

    NOTE(review): the original source had been whitespace-flattened; block
    nesting was reconstructed from context and should be verified against VCS
    history where flagged below.
    """

    def __init__(self, load_all_modules=False):
        # global configuration (~/.airc)
        self.config = misc.load_config('.airc')

        #
        # database
        #

        Session = sessionmaker(bind=model.engine)
        self.session = Session()

        #
        # TensorFlow (deferred, as tf can take quite a bit of time to set up)
        #

        self.tf_session = None
        self.nlp_model = None

        #
        # module management, setup
        #

        self.modules = {}                    # module name -> imported python module
        self.initialized_modules = set()     # modules whose init() predicate has run
        s = self.config.get('semantics', 'modules')
        self.all_modules = list(map(lambda s: s.strip(), s.split(',')))
        sys.path.append('modules')

        #
        # AIProlog parser, runtime
        #

        db_url = self.config.get('db', 'url')
        self.db = LogicDB(db_url)
        self.aip_parser = AIPrologParser(self)
        self.rt = AIPrologRuntime(self.db)
        self.dummyloc = SourceLocation('<rt>')

        #
        # alignment / word2vec (on-demand model loading)
        #

        self.w2v_model = None
        self.w2v_lang = None
        self.w2v_all_utterances = []

        #
        # load modules, if requested
        #

        if load_all_modules:
            for mn2 in self.all_modules:
                self.load_module(mn2)
                self.init_module(mn2)

    # FIXME: this will work only on the first call
    def setup_tf_model(self, mode, load_model, ini_fn, global_step=0):
        """Lazily create the TF session and NLP model; optionally restore weights.

        Because both creations are guarded by 'is None' checks, a second call
        with different arguments is silently ignored (see FIXME above).
        """
        if not self.tf_session:
            import tensorflow as tf

            # setup config to use BFC allocator
            config = tf.ConfigProto()
            # config.gpu_options.allocator_type = 'BFC'

            self.tf_session = tf.Session(config=config)

        if not self.nlp_model:
            from nlp_model import NLPModel

            self.nlp_model = NLPModel(self.session, ini_fn, global_step=global_step)

            # NOTE(review): reconstructed nesting -- the load_model branch
            # appears to sit inside the 'if not self.nlp_model' guard, which
            # is consistent with the FIXME above; confirm against VCS history.
            if load_model:
                self.nlp_model.load_dicts()

                # we need the inverse dict to reconstruct the output from tensor
                self.inv_output_dict = {
                    v: k for k, v in viewitems(self.nlp_model.output_dict)
                }

                self.tf_model = self.nlp_model.create_tf_model(self.tf_session, mode=mode)
                self.tf_model.batch_size = 1

                self.tf_model.restore(self.tf_session, self.nlp_model.model_fn)

    def clean(self, module_names, clean_all, clean_logic, clean_discourses,
              clean_cronjobs):
        """Delete logic clauses / discourses / cronjobs for the given modules.

        module_name 'all' wipes the respective tables completely; the flags
        select which categories to clean, clean_all implies all of them.
        """
        for module_name in module_names:

            if clean_logic or clean_all:
                logging.info('cleaning logic for %s...' % module_name)
                if module_name == 'all':
                    self.db.clear_all_modules()
                else:
                    self.db.clear_module(module_name)

            if clean_discourses or clean_all:
                logging.info('cleaning discourses for %s...' % module_name)
                if module_name == 'all':
                    self.session.query(model.DiscourseRound).delete()
                else:
                    self.session.query(model.DiscourseRound).filter(
                        model.DiscourseRound.module == module_name).delete()

            if clean_cronjobs or clean_all:
                logging.info('cleaning cronjobs for %s...' % module_name)
                if module_name == 'all':
                    self.session.query(model.Cronjob).delete()
                else:
                    self.session.query(model.Cronjob).filter(
                        model.Cronjob.module == module_name).delete()

        self.session.commit()

    def load_module(self, module_name):
        """Import a python AI module (idempotent), resolve its DEPENDS,
        sync its CRONJOBS table entries and run its init_module() hook.

        Returns the imported module object. On any import/setup failure the
        whole process exits via sys.exit(1).
        """
        if module_name in self.modules:
            return self.modules[module_name]

        logging.debug("loading module '%s'" % module_name)

        # fp, pathname, description = imp.find_module(module_name, ['modules'])
        # relies on sys.path.append('modules') done in __init__
        fp, pathname, description = imp.find_module(module_name)

        # print fp, pathname, description

        m = None

        try:
            m = imp.load_module(module_name, fp, pathname, description)

            self.modules[module_name] = m

            # print m
            # print getattr(m, '__all__', None)
            # for name in dir(m):
            #     print name

            # recursively load dependencies first
            for m2 in getattr(m, 'DEPENDS'):
                self.load_module(m2)

            if hasattr(m, 'CRONJOBS'):

                # update cronjobs in db: add/refresh declared ones,
                # delete db rows for jobs the module no longer declares
                old_cronjobs = set()
                for cronjob in self.session.query(model.Cronjob).filter(
                        model.Cronjob.module == module_name):
                    old_cronjobs.add(cronjob.name)

                new_cronjobs = set()
                for name, interval, f in getattr(m, 'CRONJOBS'):

                    logging.debug('registering cronjob %s' % name)

                    cj = self.session.query(model.Cronjob).filter(
                        model.Cronjob.module == module_name,
                        model.Cronjob.name == name).first()
                    if not cj:
                        cj = model.Cronjob(module=module_name, name=name, last_run=0)
                        self.session.add(cj)

                    cj.interval = interval
                    new_cronjobs.add(cj.name)

                for cjn in old_cronjobs:
                    if cjn in new_cronjobs:
                        continue
                    self.session.query(model.Cronjob).filter(
                        model.Cronjob.module == module_name,
                        model.Cronjob.name == cjn).delete()

                self.session.commit()

            if hasattr(m, 'init_module'):
                initializer = getattr(m, 'init_module')
                initializer(self)

        except:
            # NOTE(review): bare except + sys.exit(1) also swallows
            # KeyboardInterrupt/SystemExit; kept as-is.
            logging.error('failed to load module "%s"' % module_name)
            logging.error(traceback.format_exc())
            sys.exit(1)

        finally:
            # Since we may exit via an exception, close fp explicitly.
            if fp:
                fp.close()

        return m

    def init_module(self, module_name, run_trace=False):
        """Run the prolog init('<module>') goal for a module (idempotent),
        after recursively initializing its DEPENDS."""
        if module_name in self.initialized_modules:
            return

        logging.debug("initializing module '%s'" % module_name)

        self.initialized_modules.add(module_name)

        m = self.load_module(module_name)
        if not m:
            raise Exception('init_module: module "%s" not found.' % module_name)

        for m2 in getattr(m, 'DEPENDS'):
            self.init_module(m2)

        prolog_s = u'init(\'%s\')' % (module_name)
        c = self.aip_parser.parse_line_clause_body(prolog_s)

        self.rt.set_trace(run_trace)
        # solutions are computed for their side effects only
        solutions = self.rt.search(c)

    def compile_module(self, module_name):
        """Re-extract a module's training data, test cases and NER data and
        bulk-save them to the DB, replacing any previous rows."""
        m = self.modules[module_name]

        # clear module, delete old NLP training data
        self.db.clear_module(module_name, commit=True)
        self.session.query(model.TrainingData).filter(
            model.TrainingData.module == module_name).delete()
        self.session.query(model.TestCase).filter(
            model.TestCase.module == module_name).delete()
        self.session.query(model.NERData).filter(
            model.NERData.module == module_name).delete()

        # extract new training data for this module
        train_ds = []
        tests = []
        ner = {}   # lang -> class -> entity -> label

        if hasattr(m, 'nlp_train'):
            # training_data_cnt = 0
            logging.info('module %s python training data extraction...' % module_name)
            nlp_train = getattr(m, 'nlp_train')
            train_ds.extend(nlp_train(self))

        if hasattr(m, 'nlp_test'):
            logging.info('module %s python test case extraction...' % module_name)
            nlp_test = getattr(m, 'nlp_test')
            nlp_tests = nlp_test(self)
            tests.extend(nlp_tests)

        if hasattr(m, 'AIP_SOURCES'):
            logging.info('module %s AIP training data extraction...' % module_name)
            for inputfn in m.AIP_SOURCES:
                ds, ts, ne = self.aip_parser.compile_file(
                    'modules/%s/%s' % (module_name, inputfn), module_name)
                train_ds.extend(ds)
                tests.extend(ts)
                # deep-merge NER data per lang/class/entity
                for lang in ne:
                    if not lang in ner:
                        ner[lang] = {}
                    for cls in ne[lang]:
                        if not cls in ner[lang]:
                            ner[lang][cls] = {}
                        for entity in ne[lang][cls]:
                            ner[lang][cls][entity] = ne[lang][cls][entity]

        logging.info(
            'module %s training data extraction done. %d training samples, %d tests'
            % (module_name, len(train_ds), len(tests)))

        # put training data into our DB, de-duplicating on (lang, inp, resp)
        td_set = set()
        td_list = []
        for utt_lang, contexts, i, resp, loc_fn, loc_line, loc_col, prio in train_ds:

            inp = copy(contexts)
            inp.extend(i)

            inp_json = json.dumps(inp)
            resp_json = json.dumps(resp)

            # utterance = u' '.join(map(lambda c: text_type(c), contexts))
            # if utterance:
            #     utterance += u' '
            # utterance += u' '.join(i)
            utterance = u' '.join(i)

            k = utt_lang + '#0#' + '#' + inp_json + '#' + resp_json
            if not k in td_set:
                td_set.add(k)
                td_list.append(
                    model.TrainingData(
                        lang=utt_lang,
                        module=module_name,
                        utterance=utterance,
                        inp=inp_json,
                        resp=resp_json,
                        prio=prio,
                        loc_fn=loc_fn,
                        loc_line=loc_line,
                        loc_col=loc_col,
                    ))

        logging.info(
            'module %s training data conversion done. %d unique training samples.'
            % (module_name, len(td_list)))

        start_time = time.time()
        logging.info(u'bulk saving to db...')
        self.session.bulk_save_objects(td_list)
        self.session.commit()
        logging.info(u'bulk saving to db... done. Took %fs.' % (time.time() - start_time))

        # put test data into our DB
        td_list = []
        for name, lang, prep, rounds, loc_fn, loc_line, loc_col in tests:
            prep_json = prolog_to_json(prep)
            rounds_json = json.dumps(rounds)
            td_list.append(
                model.TestCase(lang=lang, module=module_name, name=name,
                               prep=prep_json, rounds=rounds_json,
                               loc_fn=loc_fn, loc_line=loc_line,
                               loc_col=loc_col))

        logging.info('module %s test data conversion done. %d tests.'
                     % (module_name, len(td_list)))

        start_time = time.time()
        logging.info(u'bulk saving to db...')
        self.session.bulk_save_objects(td_list)
        self.session.commit()
        logging.info(u'bulk saving to db... done. Took %fs.' % (time.time() - start_time))

        # put NER data into our DB
        # import pdb; pdb.set_trace()
        ner_list = []
        for lang in ner:
            for cls in ner[lang]:
                for entity in ner[lang][cls]:
                    ner_list.append(
                        model.NERData(lang=lang, module=module_name, cls=cls,
                                      entity=entity,
                                      label=ner[lang][cls][entity]))

        logging.info('module %s NER data conversion done. %d rows.'
                     % (module_name, len(ner_list)))

        start_time = time.time()
        logging.info(u'bulk saving to db...')
        self.session.bulk_save_objects(ner_list)
        self.session.commit()
        logging.info(u'bulk saving to db... done. Took %fs.' % (time.time() - start_time))

        self.session.commit()

    def compile_module_multi(self, module_names):
        """compile_module over a list of names; 'all' expands to every
        configured module."""
        for module_name in module_names:

            if module_name == 'all':
                for mn2 in self.all_modules:
                    self.load_module(mn2)
                    self.compile_module(mn2)
            else:
                self.load_module(module_name)
                self.compile_module(module_name)

        self.session.commit()

    # NOTE(review): the flattened source is ambiguous here -- this may have
    # been a live class attribute; nothing in the visible code references it.
    # _IGNORE_CONTEXT_KEYS = set([ 'user', 'lang', 'tokens', 'time', 'prev', 'resp' ])

    def _compute_net_input(self, res, cur_context):
        """Build the neural-net input list for the current context:
        sorted [key, value] context pairs followed by the utterance tokens."""
        solutions = self.rt.search_predicate('tokens', [cur_context, '_1'], env=res)
        tokens = solutions[0]['_1'].l

        solutions = self.rt.search_predicate('context', [cur_context, '_2', '_3'], env=res)

        d = {}
        for s in solutions:
            k = s['_2']
            # only predicate-typed keys are kept
            if not isinstance(k, Predicate):
                continue
            k = k.name
            v = s['_3']
            if isinstance(v, Predicate):
                v = v.name
            elif isinstance(v, StringLiteral):
                v = v.s
            else:
                v = text_type(v)
            d[k] = v

        # import pdb; pdb.set_trace()

        inp = []
        for t in reversed(tokens):
            inp.insert(0, t.s)
        # context pairs are prepended in front of the tokens
        for k in sorted(list(d)):
            inp.insert(0, [k, d[k]])

        return inp

    def find_prev_context(self, user, env={}):
        """Return the most recent context predicate asserted for *user*,
        or None.

        NOTE(review): ctxid is never updated inside the loop, so the
        'cid > ctxid' comparison is always against 0 -- the last solution
        wins rather than the numerically highest context id. Also note the
        mutable default argument env={}.
        """
        pc = None
        ctxid = 0

        # logging.debug ('find_prev_context: user=%s' % user)

        for s in self.rt.search_predicate('user', ['_1', Predicate(user)], env=env):
            # context predicate names carry a numeric suffix after 7 chars
            cid = int(s['_1'].name[7:])
            if not pc or cid > ctxid:
                pc = s['_1']
            # logging.debug ('find_prev_context: s=%s, pc=%s' % (unicode(s), unicode(pc)))

        return pc

    def _setup_context(self, user, lang, inp, prev_context, prev_res):
        """Create a fresh context predicate and assert user/lang/tokens/time
        facts into an overlay env; carry over context/mem facts from
        prev_context if given. Returns (env, new_context)."""
        cur_context = Predicate(do_gensym(self.rt, 'context'))

        res = {}
        # carry over the assertz overlay from the previous round, if any
        if ASSERT_OVERLAY_VAR_NAME in prev_res:
            res[ASSERT_OVERLAY_VAR_NAME] = prev_res[ASSERT_OVERLAY_VAR_NAME].clone()

        res = do_assertz(
            {},
            Clause(Predicate('user', [cur_context, Predicate(user)]),
                   location=self.dummyloc),
            res=res)
        res = do_assertz(
            {},
            Clause(Predicate('lang', [cur_context, Predicate(lang)]),
                   location=self.dummyloc),
            res=res)

        token_literal = ListLiteral(list(map(lambda x: StringLiteral(x), inp)))
        res = do_assertz(
            {},
            Clause(Predicate('tokens', [cur_context, token_literal]),
                   location=self.dummyloc),
            res=res)

        currentTime = datetime.datetime.utcnow().replace(tzinfo=pytz.UTC).isoformat()
        res = do_assertz(
            {},
            Clause(Predicate('time', [cur_context, StringLiteral(currentTime)]),
                   location=self.dummyloc),
            res=res)

        if prev_context:

            res = do_assertz(
                {},
                Clause(Predicate('prev', [cur_context, prev_context]),
                       location=self.dummyloc),
                res=res)

            # copy over all previous context statements to the new one
            s1s = self.rt.search_predicate('context', [prev_context, '_1', '_2'], env=res)
            for s1 in s1s:
                res = do_assertz(
                    {},
                    Clause(Predicate('context', [cur_context, s1['_1'], s1['_2']]),
                           location=self.dummyloc),
                    res=res)

            # copy over all previous mem statements to the new one
            s1s = self.rt.search_predicate('mem', [prev_context, '_1', '_2'], env=res)
            for s1 in s1s:
                res = do_assertz(
                    {},
                    Clause(Predicate('mem', [cur_context, s1['_1'], s1['_2']]),
                           location=self.dummyloc),
                    res=res)

        # import pdb; pdb.set_trace()

        res['C'] = cur_context

        return res, cur_context

    def _extract_response(self, cur_context, env):
        """Collect c_say utterances, c_action argument lists and the summed
        c_score for the given context from a solution env."""
        #import pdb; pdb.set_trace()

        res = []
        s2s = self.rt.search_predicate('c_say', [cur_context, '_1'], env=env)
        for s2 in s2s:
            if not '_1' in s2:
                continue
            res.append(s2['_1'].s)

        actions = []
        s2s = self.rt.search_predicate('c_action', [cur_context, '_1'], env=env)
        for s2 in s2s:
            if not '_1' in s2:
                continue
            actions.append(list(map(lambda x: text_type(x), s2['_1'].l)))

        score = 0.0
        s2s = self.rt.search_predicate('c_score', [cur_context, '_1'], env=env)
        for s2 in s2s:
            if not '_1' in s2:
                continue
            score += s2['_1'].f

        return res, actions, score

    def _reconstruct_prolog_code(self, acode):
        """Rebuild a Predicate tree from a flat acode token list.

        'or(' / 'and(' open a nested group, ')' closes it, any other token is
        parsed as a clause body. Returns None on unbalanced input.
        """
        todo = [('and', [])]   # stack of (operator, children) frames

        idx = 0
        while idx < len(acode):
            a = acode[idx]
            if a == 'or(':
                todo.append(('or', []))
            elif a == 'and(':
                todo.append(('and', []))
            elif a == ')':
                c = todo.pop()
                todo[len(todo) - 1][1].append(Predicate(c[0], c[1]))
            else:
                clause = self.aip_parser.parse_line_clause_body(a)
                todo[len(todo) - 1][1].append(clause.body)
            idx += 1

        if len(todo) != 1:
            logging.warn('unbalanced acode detected.')
            return None

        c = todo.pop()
        return Predicate(c[0], c[1])

    def test_module(self, module_name, run_trace=False, test_name=None):
        """Run the stored nlp test cases of a module.

        For every round of every test case, looks up the trained response code
        for the net input, executes it and compares utterance and actions
        against the expectation. Returns (num_tests, num_fails); a test case
        aborts at its first failing round.
        """
        self.rt.set_trace(run_trace)

        m = self.modules[module_name]

        logging.info('running tests of module %s ...' % (module_name))

        num_tests = 0
        num_fails = 0
        for tc in self.session.query(
                model.TestCase).filter(model.TestCase.module == module_name):

            if test_name:
                if tc.name != test_name:
                    logging.info('skipping test %s' % tc.name)
                    continue

            num_tests += 1

            rounds = json.loads(tc.rounds)
            prep = json_to_prolog(tc.prep)

            round_num = 0
            prev_context = None
            res = {}

            for t_in, t_out, test_actions in rounds:

                test_in = u' '.join(t_in)
                test_out = u' '.join(t_out)

                logging.info("nlp_test: %s round %d test_in : %s" % (tc.name, round_num, repr(test_in)))
                logging.info("nlp_test: %s round %d test_out : %s" % (tc.name, round_num, repr(test_out)))
                logging.info("nlp_test: %s round %d test_actions: %s" % (tc.name, round_num, repr(test_actions)))

                #if round_num>0:
                #    import pdb; pdb.set_trace()

                res, cur_context = self._setup_context(
                    user=TEST_USER,
                    lang=tc.lang,
                    inp=t_in,
                    prev_context=prev_context,
                    prev_res=res)

                # prep: preparation goals must yield exactly one solution each
                if prep:
                    # import pdb; pdb.set_trace()
                    # self.rt.set_trace(True)
                    for p in prep:
                        solutions = self.rt.search(Clause(
                            None, p, location=self.dummyloc), env=res)
                        if len(solutions) != 1:
                            raise (PrologRuntimeError(
                                'Expected exactly one solution from preparation code for test "%s", got %d.'
                                % (tc.name, len(solutions))))
                        res = solutions[0]

                # inp / resp
                inp = self._compute_net_input(res, cur_context)

                # look up code in DB
                acode = None
                matching_resp = False

                for tdr in self.session.query(model.TrainingData).filter(
                        model.TrainingData.lang == tc.lang,
                        model.TrainingData.inp == json.dumps(inp)):

                    if acode:
                        logging.warn(
                            u'%s: more than one acode for test_in "%s" found in DB!'
                            % (tc.name, test_in))

                    acode = json.loads(tdr.resp)
                    pcode = self._reconstruct_prolog_code(acode)
                    clause = Clause(None, pcode, location=self.dummyloc)
                    solutions = self.rt.search(clause, env=res)

                    # import pdb; pdb.set_trace()

                    for solution in solutions:

                        actual_out, actual_actions, score = self._extract_response(
                            cur_context, solution)

                        # logging.info("nlp_test: %s round %d %s" % (clause.location, round_num, repr(abuf)) )

                        if len(test_out) > 0:
                            if len(actual_out) > 0:
                                # re-tokenize for a normalized comparison
                                actual_out = u' '.join(
                                    tokenize(u' '.join(actual_out), tc.lang))
                            logging.info(
                                "nlp_test: %s round %d actual_out : %s (score: %f)"
                                % (tc.name, round_num, actual_out, score))
                            if actual_out != test_out:
                                logging.info(
                                    "nlp_test: %s round %d UTTERANCE MISMATCH."
                                    % (tc.name, round_num))
                                continue  # no match

                        logging.info(
                            "nlp_test: %s round %d UTTERANCE MATCHED!"
                            % (tc.name, round_num))

                        # check actions: every expected action must appear
                        if len(test_actions) > 0:

                            logging.info(
                                "nlp_test: %s round %d actual acts : %s"
                                % (tc.name, round_num, repr(actual_actions)))
                            # print repr(test_actions)

                            actions_matched = True
                            act = None
                            for action in test_actions:
                                for act in actual_actions:
                                    # print " check action match: %s vs %s" % (repr(action), repr(act))
                                    if action == act:
                                        break
                                if action != act:
                                    actions_matched = False
                                    break

                            if not actions_matched:
                                logging.info(
                                    "nlp_test: %s round %d ACTIONS MISMATCH."
                                    % (tc.name, round_num))
                                continue

                            logging.info(
                                "nlp_test: %s round %d ACTIONS MATCHED!"
                                % (tc.name, round_num))

                        matching_resp = True
                        # keep this solution's env for the next round
                        res = solution
                        break

                    if matching_resp:
                        break

                if acode is None:
                    logging.error('failed to find db entry for %s' % json.dumps(inp))
                    logging.error(
                        u'Error: %s: no training data for test_in "%s" found in DB!'
                        % (tc.name, test_in))
                    num_fails += 1
                    break

                if not matching_resp:
                    logging.error(
                        u'nlp_test: %s round %d no matching response found.'
                        % (tc.name, round_num))
                    num_fails += 1
                    break

                prev_context = cur_context
                round_num += 1

        self.rt.set_trace(False)

        return num_tests, num_fails

    def run_tests_multi(self, module_names, run_trace=False, test_name=None):
        """test_module over a list of names ('all' expands); returns
        aggregated (num_tests, num_fails)."""
        num_tests = 0
        num_fails = 0
        for module_name in module_names:

            if module_name == 'all':
                for mn2 in self.all_modules:
                    self.load_module(mn2)
                    self.init_module(mn2, run_trace=run_trace)
                    n, f = self.test_module(mn2, run_trace=run_trace, test_name=test_name)
                    num_tests += n
                    num_fails += f
            else:
                self.load_module(module_name)
                self.init_module(module_name, run_trace=run_trace)
                n, f = self.test_module(module_name, run_trace=run_trace, test_name=test_name)
                num_tests += n
                num_fails += f

        return num_tests, num_fails

    def _process_input_nnet(self, inp, res):
        """Predict prolog response code for *inp* with the seq2seq model and
        run it; returns the solutions (empty list on any failure)."""
        solutions = []

        logging.debug('_process_input_nnet: %s' % repr(inp))

        try:
            # ok, exact matching has not yielded any results -> use neural network to
            # generate response(s)

            x = self.nlp_model.compute_x(inp)
            # logging.debug("x: %s -> %s" % (utterance, x))

            source, source_len, dest, dest_len = self.nlp_model._prepare_batch(
                [[x, []]], offset=0)

            # predicted_ids: GreedyDecoder; [batch_size, max_time_step, 1]
            # BeamSearchDecoder; [batch_size, max_time_step, beam_width]
            predicted_ids = self.tf_model.predict(
                self.tf_session, encoder_inputs=source,
                encoder_inputs_length=source_len)

            # for seq_batch in predicted_ids:
            #     for k in range(5):
            #         logging.debug('--------- k: %d ----------' % k)
            #         seq = seq_batch[:,k]
            #         for p in seq:
            #             if p == -1:
            #                 break
            #             decoded = self.inv_output_dict[p]
            #             logging.debug (u'%s: %s' %(p, decoded))

            # extract best codes only (greedy beam 0, up to _EOS)
            acodes = [[]]
            for p in predicted_ids[0][:, 0]:
                if p == -1:
                    break
                decoded = self.inv_output_dict[p]
                if decoded == u'_EOS':
                    break
                if decoded == u'__OR__':
                    acodes.append([])
                # NOTE(review): the '__OR__' token itself is also appended to
                # the freshly opened list here -- verify this is intended.
                acodes[len(acodes) - 1].append(decoded)

            # FIXME: for now, we try the first solution only
            acode = acodes[0]

            pcode = self._reconstruct_prolog_code(acode)
            logging.debug('_process_input_nnet: %s' % pcode)
            clause = Clause(None, pcode, location=self.dummyloc)
            solutions = self.rt.search(clause, env=res)

        except:
            # probably ok (prolog code generated by neural network might not always work)
            logging.error('EXCEPTION CAUGHT %s' % traceback.format_exc())

        return solutions

    def process_input(self, utterance, utt_lang, user_uri, run_trace=False,
                      do_eliza=True, prev_ctx=None):
        """ process user input, return score, responses, actions, solutions, context """

        prev_context = prev_ctx
        res = {}

        tokens = tokenize(utterance, utt_lang)

        res, cur_context = self._setup_context(user=user_uri,
                                               lang=utt_lang,
                                               inp=tokens,
                                               prev_context=prev_context,
                                               prev_res=res)

        inp = self._compute_net_input(res, cur_context)

        logging.debug('process_input: %s' % repr(inp))

        #
        # do we have an exact match in our training data for this input?
        #

        solutions = []

        self.rt.set_trace(run_trace)

        for tdr in self.session.query(model.TrainingData).filter(
                model.TrainingData.lang == utt_lang,
                model.TrainingData.inp == json.dumps(inp)):
            acode = json.loads(tdr.resp)
            pcode = self._reconstruct_prolog_code(acode)
            clause = Clause(None, pcode, location=self.dummyloc)
            sols = self.rt.search(clause, env=res)
            if sols:
                solutions.extend(sols)

        # fall back to the neural network if no exact match produced solutions
        if not solutions:
            solutions = self._process_input_nnet(inp, res)

        #
        # try dropping the context if we haven't managed to produce a result yet
        #

        if not solutions:
            res, cur_context = self._setup_context(user=user_uri,
                                                   lang=utt_lang,
                                                   inp=tokens,
                                                   prev_context=None,
                                                   prev_res={})
            inp = self._compute_net_input(res, cur_context)
            solutions = self._process_input_nnet(inp, res)

        # last resort: canned ELIZA-style response
        if not solutions and do_eliza:
            logging.info('producing ELIZA-style response for input %s' % utterance)
            clause = self.aip_parser.parse_line_clause_body(
                'do_eliza(C, %s)' % utt_lang)
            solutions = self.rt.search(clause, env=res)

        self.rt.set_trace(False)

        #
        # extract highest-scoring responses only:
        #

        best_score = 0
        best_resps = []
        best_actions = []
        best_solutions = []

        for solution in solutions:

            actual_resp, actual_actions, score = self._extract_response(
                cur_context, solution)

            # a new best score resets the collected candidates
            if score > best_score:
                best_score = score
                best_resps = []
                best_actions = []
                best_solutions = []

            if score < best_score:
                continue

            best_resps.append(actual_resp)
            best_actions.append(actual_actions)
            best_solutions.append(solution)

        return best_score, best_resps, best_actions, best_solutions, cur_context

    def run_cronjobs(self, module_name, force=False, run_trace=False):
        """Run the module's declared CRONJOBS that are due (or all, if
        force=True), updating their last_run timestamps in the session."""
        m = self.modules[module_name]
        if not hasattr(m, 'CRONJOBS'):
            return

        self.rt.set_trace(run_trace)

        for name, interval, f in getattr(m, 'CRONJOBS'):

            cronjob = self.session.query(model.Cronjob).filter(
                model.Cronjob.module == module_name,
                model.Cronjob.name == name).first()

            t = time.time()
            next_run = cronjob.last_run + interval

            if force or t > next_run:

                logging.debug('running cronjob %s' % name)
                f(self)

                cronjob.last_run = t

    def run_cronjobs_multi(self, module_names, force, run_trace=False):
        """run_cronjobs over a list of names ('all' expands), then commit
        the updated last_run timestamps."""
        for module_name in module_names:

            if module_name == 'all':
                for mn2 in self.all_modules:
                    self.load_module(mn2)
                    self.init_module(mn2)
                    self.run_cronjobs(mn2, force=force, run_trace=run_trace)
            else:
                self.load_module(module_name)
                self.init_module(module_name)
                self.run_cronjobs(module_name, force=force, run_trace=run_trace)

        self.session.commit()

    def train(self, ini_fn, num_steps, incremental):
        """Set up the TF model in train mode (no weight restore) and train."""
        self.setup_tf_model('train', False, ini_fn)
        self.nlp_model.train(num_steps, incremental)

    def dump_utterances(self, num_utterances, dictfn, lang, module):
        """Print training utterances for a language (optionally limited to a
        module). With a dictionary file, only utterances containing at least
        one out-of-dictionary word are selected; num_utterances > 0 prints a
        random sample of that size, otherwise all unique ones."""
        dic = None
        if dictfn:
            dic = set()
            with codecs.open(dictfn, 'r', 'utf8') as dictf:
                for line in dictf:
                    parts = line.strip().split(';')
                    if len(parts) != 2:
                        continue
                    dic.add(parts[0])

        all_utterances = []

        req = self.session.query(
            model.TrainingData).filter(model.TrainingData.lang == lang)

        if module and module != 'all':
            req = req.filter(model.TrainingData.module == module)

        for dr in req:

            if not dic:
                all_utterances.append(dr.utterance)
            else:
                # is at least one word not covered by our dictionary?
                unk = False
                for t in tokenize(dr.utterance):
                    if not t in dic:
                        # print u"unknown word: %s in %s" % (t, dr.utterance)
                        unk = True
                        # remember the word so later duplicates don't re-match
                        dic.add(t)
                        break
                if not unk:
                    continue

                all_utterances.append(dr.utterance)

        utts = set()

        if num_utterances > 0:
            # NOTE(review): loops forever if fewer unique utterances exist
            # than num_utterances requested.
            while (len(utts) < num_utterances):
                i = random.randrange(0, len(all_utterances))
                utts.add(all_utterances[i])
        else:
            for utt in all_utterances:
                utts.add(utt)

        for utt in utts:
            print(utt)

    def setup_align_utterances(self, lang):
        """Lazily cache all training utterances for *lang* and load the
        configured word2vec model (once)."""
        if self.w2v_model and self.w2v_lang == lang:
            return

        logging.debug('loading all utterances from db...')

        self.w2v_all_utterances = []
        req = self.session.query(
            model.TrainingData).filter(model.TrainingData.lang == lang)
        for dr in req:
            self.w2v_all_utterances.append(
                (dr.utterance, dr.module, dr.loc_fn, dr.loc_line))

        if not self.w2v_model:
            from gensim.models import word2vec

            model_fn = self.config.get('semantics', 'word2vec_model_%s' % lang)
            logging.debug('loading word2vec model %s ...' % model_fn)
            logging.getLogger('gensim.models.word2vec').setLevel(logging.WARNING)
            self.w2v_model = word2vec.Word2Vec.load_word2vec_format(model_fn, binary=True)
            self.w2v_lang = lang

            #list containing names of words in the vocabulary
            self.w2v_index2word_set = set(self.w2v_model.index2word)

            logging.debug('loading word2vec model %s ... done' % model_fn)

    def align_utterances(self, lang, utterances):
        """For each given utterance, find the up-to-6 most similar training
        utterances by cosine similarity of averaged word2vec feature vectors.

        Returns {utterance: [(sim, location, matched_utterance), ...]}.
        NOTE(review): sims.iteritems() is Python-2 only.
        """
        self.setup_align_utterances(lang)

        res = {}

        for utt1 in utterances:
            try:
                utt1t = tokenize(utt1, lang=lang)
                av1 = avg_feature_vector(
                    utt1t, model=self.w2v_model, num_features=300,
                    index2word_set=self.w2v_index2word_set)

                sims = {}  # location -> score
                utts = {}  # location -> utterance

                for utt2, module, loc_fn, loc_line in self.w2v_all_utterances:
                    try:
                        utt2t = tokenize(utt2, lang=lang)
                        av2 = avg_feature_vector(
                            utt2t, model=self.w2v_model, num_features=300,
                            index2word_set=self.w2v_index2word_set)
                        sim = 1 - cosine(av1, av2)

                        location = '%s:%s:%d' % (module, loc_fn, loc_line)
                        sims[location] = sim
                        utts[location] = utt2
                        # logging.debug('%10.8f %s' % (sim, location))
                    except:
                        # best effort: skip utterances that fail to vectorize
                        logging.error('EXCEPTION CAUGHT %s' % traceback.format_exc())

                logging.info('sims for %s' % repr(utt1))
                cnt = 0
                res[utt1] = []
                for sim, location in sorted(
                        ((v, k) for k, v in sims.iteritems()), reverse=True):
                    logging.info('%10.8f %s' % (sim, location))
                    logging.info(' %s' % (utts[location]))
                    res[utt1].append((sim, location, utts[location]))
                    cnt += 1
                    if cnt > 5:
                        break
            except:
                logging.error('EXCEPTION CAUGHT %s' % traceback.format_exc())

        return res
class TestEmbeddings (unittest.TestCase):
    """Tests for prolog runtime extension points: custom builtins and
    custom parser directives, run against a throwaway sqlite database."""

    def setUp(self):

        #
        # db, store
        #

        db_url = 'sqlite:///foo.db'

        # setup compiler + environment
        self.db = LogicDB(db_url, echo=False)
        self.parser = PrologParser(self.db)
        self.rt = PrologRuntime(self.db)
        # self.rt.set_trace(True)
        self.db.clear_module(UNITTEST_MODULE)

    def tearDown(self):
        # release the sqlite database
        self.db.close()

    #@unittest.skip("temporarily disabled")
    def test_custom_builtins(self):
        # record_move is registered as a builtin; solving hanoi with 3 discs
        # must invoke it 2^3 - 1 = 7 times
        global recorded_moves

        self.parser.compile_file('samples/hanoi2.pl', UNITTEST_MODULE)
        clause = self.parser.parse_line_clause_body('move(3,left,right,center)')
        logging.debug('clause: %s' % clause)

        # register our custom builtin
        recorded_moves = []
        self.rt.register_builtin('record_move', record_move)

        solutions = self.rt.search(clause)
        logging.debug('solutions: %s' % repr(solutions))
        self.assertEqual (len(solutions), 1)
        self.assertEqual (len(recorded_moves), 7)

    #@unittest.skip("temporarily disabled")
    def test_custom_builtin_multiple_bindings(self):
        # a builtin may yield several variable bindings -> several solutions
        self.rt.register_builtin('multi_binder', multi_binder)
        clause = self.parser.parse_line_clause_body('multi_binder(X,Y)')
        logging.debug('clause: %s' % clause)
        solutions = self.rt.search(clause)
        logging.debug('solutions: %s' % repr(solutions))
        self.assertEqual (len(solutions), 4)

    def _custom_directive(self, db, module_name, clause, user_data):
        # directive callback: verify the parsed arguments, then set a flag
        # the test method checks.
        # NOTE(review): unicode() is Python-2 only -- confirm this test suite
        # is intended for py2.
        # logging.debug('custom_directive has been run')
        self.assertEqual (len(clause.head.args), 3)
        self.assertEqual (unicode(clause.head.args[0]), u'abc')
        self.assertEqual (clause.head.args[1].f, 42)
        self.assertEqual (clause.head.args[2].s, u'foo')
        self.directive_mark = True

    #@unittest.skip("temporarily disabled")
    def test_custom_directives(self):
        # parsing a registered directive must invoke its callback
        self.parser.register_directive('custom_directive', self._custom_directive, None)
        self.directive_mark = False
        # self.parser.compile_file('samples/dir.pl', UNITTEST_MODULE)
        clauses = self.parser.parse_line_clauses('custom_directive(abc, 42, \'foo\').')
        self.assertEqual (self.directive_mark, True)
class AIKernal(object):
    """Central engine object of the AI system.

    Owns the SQLAlchemy session, the logic (Prolog clause) DB, the RDF
    knowledge base, the Prolog runtime/parser and (lazily) a TensorFlow
    seq2seq NLP model. Provides module loading/compiling, KB import,
    input processing, testing and cronjob execution.
    """

    def __init__(self):

        self.config = misc.load_config('.airc')

        #
        # database
        #

        Session = sessionmaker(bind=model.engine)
        self.session = Session()

        #
        # logic DB
        #

        self.db = LogicDB(model.url)

        #
        # knowledge base
        #

        self.kb = AIKB()

        #
        # TensorFlow (deferred, as tf can take quite a bit of time to set up)
        #

        self.tf_session = None
        self.nlp_model = None

        #
        # module management, setup
        #

        self.modules = {}
        s = self.config.get('semantics', 'modules')
        self.all_modules = map (lambda s: s.strip(), s.split(','))

        #
        # prolog environment setup
        #

        self.prolog_rt = AIPrologRuntime(self.db, self.kb)
        self.parser = AIPrologParser()

    # FIXME: this will work only on the first call
    def setup_tf_model (self, forward_only, load_model):
        """Lazily create the TF session and NLP model.

        :param forward_only: build the model for inference only (no training ops)
        :param load_model:   also load dicts + trained weights from disk
        """

        if not self.tf_session:

            import tensorflow as tf

            # setup config to use BFC allocator
            config = tf.ConfigProto()
            config.gpu_options.allocator_type = 'BFC'

            self.tf_session = tf.Session(config=config)

        if not self.nlp_model:

            from nlp_model import NLPModel

            self.nlp_model = NLPModel(self.session)

            # NOTE(review): the load_model block appears to run only when the
            # NLP model is freshly created (see FIXME above) — a second call
            # with load_model=True would be a no-op; confirm intended.
            if load_model:

                self.nlp_model.load_dicts()

                # we need the inverse dict to reconstruct the output from tensor
                self.inv_output_dict = {v: k for k, v in self.nlp_model.output_dict.iteritems()}

                self.tf_model = self.nlp_model.create_tf_model(self.tf_session, forward_only = forward_only)
                self.tf_model.batch_size = 1

                self.nlp_model.load_model(self.tf_session)

    def clean (self, module_names, clean_all, clean_logic, clean_discourses, clean_cronjobs, clean_kb):
        """Delete stored artifacts (logic clauses, discourse rounds, cronjobs,
        KB graphs) for the given modules; 'all' wipes everything of that kind."""

        for module_name in module_names:

            if clean_logic or clean_all:
                logging.info('cleaning logic for %s...' % module_name)
                if module_name == 'all':
                    self.db.clear_all_modules()
                else:
                    self.db.clear_module(module_name)

            if clean_discourses or clean_all:
                logging.info('cleaning discourses for %s...' % module_name)
                if module_name == 'all':
                    self.session.query(model.DiscourseRound).delete()
                else:
                    self.session.query(model.DiscourseRound).filter(model.DiscourseRound.module==module_name).delete()

            if clean_cronjobs or clean_all:
                logging.info('cleaning cronjobs for %s...' % module_name)
                if module_name == 'all':
                    self.session.query(model.Cronjob).delete()
                else:
                    self.session.query(model.Cronjob).filter(model.Cronjob.module==module_name).delete()

            if clean_kb or clean_all:
                logging.info('cleaning kb for %s...' % module_name)
                if module_name == 'all':
                    self.kb.clear_all_graphs()
                else:
                    graph = self._module_graph_name(module_name)
                    self.kb.clear_graph(graph)

        self.session.commit()

    def load_module (self, module_name, run_init=False, run_trace=False):
        """Import a python module from modules/, register its RDF prefixes,
        LDF endpoints, aliases and cronjobs; recursively loads DEPENDS.

        :param run_init:  also run the module's Prolog init/1 predicate
        :param run_trace: enable Prolog tracing while running init
        :return: the imported module object (cached on repeat calls), or
                 None if the import failed
        """

        if module_name in self.modules:
            return self.modules[module_name]

        logging.debug("loading module '%s'" % module_name)

        fp, pathname, description = imp.find_module(module_name, ['modules'])

        # print fp, pathname, description

        m = None

        try:
            m = imp.load_module(module_name, fp, pathname, description)

            self.modules[module_name] = m

            # print m
            # print getattr(m, '__all__', None)
            # for name in dir(m):
            #     print name

            # load dependencies first
            for m2 in getattr (m, 'DEPENDS'):
                self.load_module(m2, run_init=run_init, run_trace=run_trace)

            if hasattr(m, 'RDF_PREFIXES'):
                prefixes = getattr(m, 'RDF_PREFIXES')
                for prefix in prefixes:
                    self.kb.register_prefix(prefix, prefixes[prefix])

            if hasattr(m, 'LDF_ENDPOINTS'):
                endpoints = getattr(m, 'LDF_ENDPOINTS')
                for endpoint in endpoints:
                    self.kb.register_endpoint(endpoint, endpoints[endpoint])

            if hasattr(m, 'RDF_ALIASES'):
                aliases = getattr(m, 'RDF_ALIASES')
                for alias in aliases:
                    self.kb.register_alias(alias, aliases[alias])

            if hasattr(m, 'CRONJOBS'):

                # update cronjobs in db: upsert declared jobs, drop stale ones

                old_cronjobs = set()
                for cronjob in self.session.query(model.Cronjob).filter(model.Cronjob.module==module_name):
                    old_cronjobs.add(cronjob.name)

                new_cronjobs = set()
                for name, interval, f in getattr (m, 'CRONJOBS'):

                    logging.debug ('registering cronjob %s' %name)

                    cj = self.session.query(model.Cronjob).filter(model.Cronjob.module==module_name, model.Cronjob.name==name).first()
                    if not cj:
                        cj = model.Cronjob(module=module_name, name=name, last_run=0)
                        self.session.add(cj)

                    cj.interval = interval
                    new_cronjobs.add(cj.name)

                for cjn in old_cronjobs:
                    if cjn in new_cronjobs:
                        continue
                    self.session.query(model.Cronjob).filter(model.Cronjob.module==module_name, model.Cronjob.name==cjn).delete()

                self.session.commit()

            if run_init:

                # seed the context graph with the default user, then run init/1

                gn = rdflib.Graph(identifier=CONTEXT_GRAPH_NAME)

                self.kb.remove((CURIN, None, None, gn))

                quads = [ ( CURIN, KB_PREFIX+u'user', DEFAULT_USER, gn) ]
                self.kb.addN_resolve(quads)

                prolog_s = u'init(\'%s\')' % (module_name)
                c = self.parser.parse_line_clause_body(prolog_s)

                self.prolog_rt.set_trace(run_trace)

                self.prolog_rt.reset_actions()
                solutions = self.prolog_rt.search(c)

                # import pdb; pdb.set_trace()

                actions = self.prolog_rt.get_actions()
                for action in actions:
                    self.prolog_rt.execute_builtin_actions(action)

        # NOTE(review): bare except also swallows KeyboardInterrupt/SystemExit;
        # consider narrowing to `except Exception:`.
        except:
            logging.error(traceback.format_exc())

        finally:
            # Since we may exit via an exception, close fp explicitly.
            if fp:
                fp.close()

        return m

    def _module_graph_name (self, module_name):
        # graph URI for a module is simply the KB prefix + module name
        return KB_PREFIX + module_name

    def import_kb (self, module_name):
        """Import a (loaded) module's KB_SOURCES into its KB graph:
        LDF endpoint mirrors first (incremental), then local n3 files."""

        graph = self._module_graph_name(module_name)

        self.kb.register_graph(graph)

        # disabled to enable incremental kb updates
        self.kb.clear_graph(graph)

        m = self.modules[module_name]

        # import LDF first as it is incremental
        res_paths = []
        for kb_entry in getattr (m, 'KB_SOURCES'):
            if not isinstance(kb_entry, basestring):
                res_paths.append(kb_entry)

        if len(res_paths)>0:
            logging.info('mirroring from LDF endpoints, target graph: %s ...' % graph)
            quads = self.kb.ldf_mirror(res_paths, graph)

        # now import files, if any
        for kb_entry in getattr (m, 'KB_SOURCES'):
            if isinstance(kb_entry, basestring):

                kb_pathname = 'modules/%s/%s' % (module_name, kb_entry)

                logging.info('importing %s ...' % kb_pathname)
                self.kb.parse_file(graph, 'n3', kb_pathname)

    def import_kb_multi (self, module_names):
        """Load + KB-import each named module; 'all' expands to every module."""

        for module_name in module_names:

            if module_name == 'all':

                for mn2 in self.all_modules:
                    self.load_module (mn2)
                    self.import_kb (mn2)

            else:

                self.load_module (module_name)
                self.import_kb (module_name)

        self.session.commit()

    def compile_module (self, module_name, trace=False, print_utterances=False, warn_level=0):
        """Compile all PL_SOURCES of a (loaded) module into the logic DB."""

        m = self.modules[module_name]

        logging.debug('parsing sources of module %s (print_utterances: %s) ...' % (module_name, print_utterances))

        compiler = AIPrologParser (trace=trace, print_utterances=print_utterances, warn_level=warn_level)

        compiler.clear_module(module_name, self.db)

        for pl_fn in getattr (m, 'PL_SOURCES'):

            pl_pathname = 'modules/%s/%s' % (module_name, pl_fn)

            logging.debug(' parsing %s ...' % pl_pathname)
            compiler.compile_file (pl_pathname, module_name, self.db, self.kb)

    def compile_module_multi (self, module_names, run_trace=False, print_utterances=False, warn_level=0):
        """Load + compile each named module; 'all' expands to every module."""

        for module_name in module_names:

            if module_name == 'all':

                for mn2 in self.all_modules:
                    self.load_module (mn2)
                    self.compile_module (mn2, run_trace, print_utterances, warn_level)

            else:
                self.load_module (module_name)
                self.compile_module (module_name, run_trace, print_utterances, warn_level)

        self.session.commit()

    def process_input (self, utterance, utt_lang, user_uri, test_mode=False, trace=False):
        """ process user input, return action(s) """

        gn = rdflib.Graph(identifier=CONTEXT_GRAPH_NAME)

        tokens = tokenize(utterance, utt_lang)

        # (re-)populate the context graph with the current input facts
        self.kb.remove((CURIN, None, None, gn))
        quads = [ ( CURIN, KB_PREFIX+u'user', user_uri, gn),
                  ( CURIN, KB_PREFIX+u'utterance', utterance, gn),
                  ( CURIN, KB_PREFIX+u'uttLang', utt_lang, gn),
                  ( CURIN, KB_PREFIX+u'tokens', pl_literal_to_rdf(ListLiteral(tokens), self.kb), gn) ]

        if test_mode:
            # fixed timestamp so tests are reproducible
            quads.append( ( CURIN, KB_PREFIX+u'currentTime', pl_literal_to_rdf(NumberLiteral(TEST_TIME), self.kb), gn ) )
        else:
            quads.append( ( CURIN, KB_PREFIX+u'currentTime', pl_literal_to_rdf(NumberLiteral(time.time()), self.kb), gn ) )

        self.kb.addN_resolve(quads)

        self.prolog_rt.reset_actions()

        if test_mode:

            # in test mode, replay the stored discourse rounds for this utterance

            for dr in self.db.session.query(model.DiscourseRound).filter(model.DiscourseRound.inp==utterance, model.DiscourseRound.lang==utt_lang):

                prolog_s = ','.join(dr.resp.split(';'))

                logging.info("test tokens=%s prolog_s=%s" % (repr(tokens), prolog_s) )

                c = self.parser.parse_line_clause_body(prolog_s)
                # logging.debug( "Parse result: %s" % c)

                # logging.debug( "Searching for c: %s" % c )

                solutions = self.prolog_rt.search(c)

                # if len(solutions) == 0:
                #     raise PrologError ('nlp_test: %s no solution found.' % clause.location)
                # print "round %d utterances: %s" % (round_num, repr(prolog_rt.get_utterances()))

        return self.prolog_rt.get_actions()

    # FIXME: merge into process_input
    def process_line(self, line):
        """Run one input line through the seq2seq model, execute the predicted
        Prolog clause and return the chosen action buffer (or None on error)."""

        self.setup_tf_model (True, True)

        from nlp_model import BUCKETS

        x = self.nlp_model.compute_x(line)

        logging.debug("x: %s -> %s" % (line, x))

        # which bucket does it belong to?
        bucket_id = min([b for b in xrange(len(BUCKETS)) if BUCKETS[b][0] > len(x)])

        # get a 1-element batch to feed the sentence to the model
        encoder_inputs, decoder_inputs, target_weights = self.tf_model.get_batch( {bucket_id: [(x, [])]}, bucket_id )

        # print "encoder_inputs, decoder_inputs, target_weights", encoder_inputs, decoder_inputs, target_weights

        # get output logits for the sentence
        _, _, output_logits = self.tf_model.step(self.tf_session, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)

        logging.debug("output_logits: %s" % repr(output_logits))

        # this is a greedy decoder - outputs are just argmaxes of output_logits.
        outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]

        # print "outputs", outputs

        preds = map (lambda o: self.inv_output_dict[o], outputs)
        logging.debug("preds: %s" % repr(preds))

        # assemble the predicted tokens into a Prolog clause body
        prolog_s = ''

        for p in preds:

            if p[0] == '_':
                continue # skip _EOS

            if len(prolog_s)>0:
                prolog_s += ', '
            prolog_s += p

        logging.debug('?- %s' % prolog_s)

        try:
            c = self.parser.parse_line_clause_body(prolog_s)
            logging.debug( "Parse result: %s" % c)

            self.prolog_rt.reset_actions()

            self.prolog_rt.search(c)

            abufs = self.prolog_rt.get_actions()

            # if we have multiple abufs, pick one at random
            if len(abufs)>0:

                abuf = random.choice(abufs)

                self.prolog_rt.execute_builtin_actions(abuf)

                self.db.commit()

                return abuf

        except PrologError as e:

            logging.error("*** ERROR: %s" % e)

        return None

    def test_module (self, module_name, trace=False):
        """Run all stored NLP tests of a module: execute test_setup, then for
        each ivr round compare produced utterances/actions against expected.

        :raises PrologError: on malformed test clauses or when no action
                             buffer matches the expected output of a round
        """

        logging.info('running tests of module %s ...' % (module_name))

        gn = rdflib.Graph(identifier=CONTEXT_GRAPH_NAME)

        for nlpt in self.db.session.query(model.NLPTest).filter(model.NLPTest.module==module_name):

            # import pdb; pdb.set_trace()

            # test setup predicate for this module

            self.kb.remove((CURIN, None, None, gn))
            quads = [ ( CURIN, KB_PREFIX+u'user', TEST_USER, gn) ]
            self.kb.addN_resolve(quads)

            prolog_s = u'test_setup(\'%s\')' % (module_name)
            c = self.parser.parse_line_clause_body(prolog_s)

            self.prolog_rt.set_trace(trace)

            self.prolog_rt.reset_actions()
            solutions = self.prolog_rt.search(c)

            actions = self.prolog_rt.get_actions()
            for action in actions:
                self.prolog_rt.execute_builtin_actions(action)

            # extract test rounds, look up matching discourse_rounds, execute them

            clause = self.parser.parse_line_clause_body(nlpt.test_src)
            clause.location = nlpt.location
            logging.debug( "Parse result: %s (%s)" % (clause, clause.__class__))

            args = clause.body.args

            lang = args[0].name

            round_num = 0
            for ivr in args[1:]:

                if ivr.name != 'ivr':
                    raise PrologError ('nlp_test: ivr predicate args expected.')

                test_in = ''
                test_out = ''
                test_actions = []

                for e in ivr.args:

                    if e.name == 'in':
                        test_in = ' '.join(tokenize(e.args[0].s, lang))
                    elif e.name == 'out':
                        test_out = ' '.join(tokenize(e.args[0].s, lang))
                    elif e.name == 'action':
                        test_actions.append(e.args)
                    else:
                        raise PrologError (u'nlp_test: ivr predicate: unexpected arg: ' + unicode(e))

                logging.info("nlp_test: %s round %d test_in : %s" % (clause.location, round_num, test_in) )
                logging.info("nlp_test: %s round %d test_out : %s" % (clause.location, round_num, test_out) )
                logging.info("nlp_test: %s round %d test_actions: %s" % (clause.location, round_num, test_actions) )

                # execute all matching clauses, collect actions
                # FIXME: nlp_test should probably let the user specify a user
                action_buffers = self.process_input (test_in, lang, TEST_USER, test_mode=True, trace=trace)

                # check actual actions vs expected ones
                matching_abuf = None
                for abuf in action_buffers:

                    logging.info("nlp_test: %s round %d %s" % (clause.location, round_num, repr(abuf)) )

                    # check utterance: collect all say-actions into one string

                    actual_out = u''
                    utt_lang = u'en'

                    for action in abuf['actions']:
                        p = action[0].name
                        if p == 'say':
                            utt_lang = unicode(action[1])
                            actual_out += u' ' + unicode(action[2])

                    if len(test_out) > 0:
                        if len(actual_out)>0:
                            actual_out = u' '.join(tokenize(actual_out, utt_lang))
                        if actual_out != test_out:
                            logging.info("nlp_test: %s round %d UTTERANCE MISMATCH." % (clause.location, round_num))
                            continue # no match

                    logging.info("nlp_test: %s round %d UTTERANCE MATCHED!" % (clause.location, round_num))

                    # check actions

                    if len(test_actions)>0:

                        # import pdb; pdb.set_trace()

                        # print repr(test_actions)

                        actions_matched = True
                        for action in test_actions:
                            for act in abuf['actions']:
                                # print "   check action match: %s vs %s" % (repr(action), repr(act))
                                if action == act:
                                    break
                            if action != act:
                                actions_matched = False
                                break

                        if not actions_matched:
                            logging.info("nlp_test: %s round %d ACTIONS MISMATCH." % (clause.location, round_num))
                            continue

                        logging.info("nlp_test: %s round %d ACTIONS MATCHED!" % (clause.location, round_num))

                    matching_abuf = abuf
                    break

                if not matching_abuf:
                    raise PrologError (u'nlp_test: %s round %d no matching abuf found.' % (clause.location, round_num))

                self.prolog_rt.execute_builtin_actions(matching_abuf)

                round_num += 1

        logging.info('running tests of module %s complete!' % (module_name))

    def run_tests_multi (self, module_names, run_trace=False):
        """Load (with init) + test each named module; 'all' expands."""

        for module_name in module_names:

            if module_name == 'all':

                for mn2 in self.all_modules:
                    self.load_module (mn2, run_init=True, run_trace=run_trace)
                    self.test_module (mn2, run_trace)

            else:
                self.load_module (module_name, run_init=True, run_trace=run_trace)
                self.test_module (module_name, run_trace)

    def run_cronjobs (self, module_name, force=False):
        """Run a module's cronjobs whose interval has elapsed (or all, if force)."""

        m = self.modules[module_name]
        if not hasattr(m, 'CRONJOBS'):
            return

        graph = self._module_graph_name(module_name)

        self.kb.register_graph(graph)

        for name, interval, f in getattr (m, 'CRONJOBS'):

            cronjob = self.session.query(model.Cronjob).filter(model.Cronjob.module==module_name, model.Cronjob.name==name).first()

            t = time.time()
            next_run = cronjob.last_run + interval

            if force or t > next_run:

                logging.debug ('running cronjob %s' %name)
                f (self.config, self.kb, graph)

                cronjob.last_run = t

    def run_cronjobs_multi (self, module_names, force, run_trace=False):
        """Load (with init) + run cronjobs for each named module; 'all' expands."""

        for module_name in module_names:

            if module_name == 'all':

                for mn2 in self.all_modules:
                    self.load_module (mn2, run_init=True, run_trace=run_trace)
                    self.run_cronjobs (mn2, force=force)

            else:
                self.load_module (module_name, run_init=True, run_trace=run_trace)
                self.run_cronjobs (module_name, force=force)

        self.session.commit()

    def train (self, num_steps):
        # build the full (trainable) TF model, then train it
        self.setup_tf_model (False, False)
        self.nlp_model.train(num_steps)

    def dump_utterances (self, num_utterances, dictfn):
        """Print stored discourse-round inputs to stdout.

        :param num_utterances: >0 prints a random sample of that size,
                               otherwise all (deduplicated) utterances
        :param dictfn: optional ';'-separated dictionary file; when given,
                       only utterances containing at least one word NOT in
                       the dictionary are considered
        """

        dic = None
        if dictfn:
            dic = set()
            with codecs.open(dictfn, 'r', 'utf8') as dictf:
                for line in dictf:
                    parts = line.strip().split(';')
                    if len(parts) != 2:
                        continue
                    dic.add(parts[0])

        all_utterances = []
        for dr in self.session.query(model.DiscourseRound):

            if not dic:
                all_utterances.append(dr.inp)
            else:

                # is at least one word not covered by our dictionary?

                unk = False
                for t in tokenize(dr.inp):
                    if not t in dic:
                        # print u"unknown word: %s in %s" % (t, dr.inp)
                        unk = True
                        break
                if not unk:
                    continue

                all_utterances.append(dr.inp)

        utts = set()

        if num_utterances > 0:

            # NOTE(review): if num_utterances exceeds the number of distinct
            # utterances available this loop never terminates; also crashes
            # with randrange on an empty list — consider guarding.
            while (len(utts) < num_utterances):

                i = random.randrange(0, len(all_utterances))
                utts.add(all_utterances[i])

        else:
            for utt in all_utterances:
                utts.add(utt)

        for utt in utts:
            print utt
class TestZamiaProlog(unittest.TestCase):
    """Core parser + runtime tests: parsing, knowledge bases, or/cut/var
    handling, list semantics and error locations."""

    def setUp(self):
        # fresh sqlite-backed logic DB, compiler and runtime
        self.db     = LogicDB('sqlite:///foo.db')
        self.parser = PrologParser(self.db)
        self.rt     = PrologRuntime(self.db)

        self.db.clear_module(UNITTEST_MODULE)

    def tearDown(self):
        self.db.close()

    # @unittest.skip("temporarily disabled")
    def test_parser(self):
        """An unterminated string literal must raise a PrologError."""

        caught = False
        try:
            c = self.parser.parse_line_clause_body('say_eoa(en, "Kids are the best')
            logging.debug('clause: %s' % c)
        except PrologError as e:
            caught = True

        self.assertEqual(caught, True)

    # @unittest.skip("temporarily disabled")
    def test_parse_line_clauses(self):
        """Single- and multi-goal bodies parse into the expected tree shapes."""

        clauses = self.parser.parse_line_clauses('time_span(TE) :- date_time_stamp(+(D, 1.0)).')
        logging.debug(unicode(clauses[0].body))
        self.assertEqual(clauses[0].body.name, 'date_time_stamp')
        self.assertEqual(clauses[0].head.name, 'time_span')

        # a four-goal conjunction becomes an 'and' node with 4 args
        clauses = self.parser.parse_line_clauses('time_span(tomorrow, TS, TE) :- context(currentTime, T), stamp_date_time(T, date(Y, M, D, H, Mn, S, "local")), date_time_stamp(date(Y, M, +(D, 1.0), 0.0, 0.0, 0.0, "local"), TS), date_time_stamp(date(Y, M, +(D, 1.0), 23.0, 59.0, 59.0, "local"), TE).')
        logging.debug(unicode(clauses[0].body))
        self.assertEqual(clauses[0].head.name, 'time_span')
        self.assertEqual(clauses[0].body.name, 'and')
        self.assertEqual(len(clauses[0].body.args), 4)

    # @unittest.skip("temporarily disabled")
    def test_kb1(self):
        """Compiling a file adds clauses to the DB and makes them queryable."""

        self.assertEqual(len(self.db.lookup('party', 0)), 0)
        self.parser.compile_file('samples/kb1.pl', UNITTEST_MODULE)
        self.assertEqual(len(self.db.lookup('party', 0)), 1)

        c = self.parser.parse_line_clause_body('woman(X)')
        logging.debug('clause: %s' % c)
        sols = self.rt.search(c)
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 3)

        c = self.parser.parse_line_clause_body('party')
        logging.debug('clause: %s' % c)
        sols = self.rt.search(c)
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 1)

        # fred is not a woman -> no solutions
        c = self.parser.parse_line_clause_body('woman(fred)')
        logging.debug('clause: %s' % c)
        sols = self.rt.search(c)
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 0)

    # @unittest.skip("temporarily disabled")
    def test_parse_to_string(self):
        """Pretty-printing a parsed clause yields the canonical or/and form."""

        src       = u'time_span(c, X, Y) :- p1(c), p2(X, Y); p3(c); p4.'
        canonical = u'time_span(c, X, Y) :- or(and(p1(c), p2(X, Y)), p3(c), p4).'

        clauses = self.parser.parse_line_clauses(src)
        logging.debug(unicode(clauses[0].body))
        self.assertEqual(unicode(clauses[0]), canonical)

    # @unittest.skip("temporarily disabled")
    def test_or(self):
        """Disjunctions in clause bodies enumerate all alternatives."""

        self.parser.compile_file('samples/or_test.pl', UNITTEST_MODULE)

        # self.rt.set_trace(True)

        sols = self.rt.search_predicate('woman', ['X'])
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 3)

        sols = self.rt.search_predicate('human', ['X'])
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 8)

    def test_or_toplevel(self):
        """A top-level disjunction works in a query body as well."""

        self.parser.compile_file('samples/or_test.pl', UNITTEST_MODULE)

        c = self.parser.parse_line_clause_body(u'woman(mary); woman(jody)')
        logging.debug(u'clause: %s' % c)
        sols = self.rt.search(c)
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 1)

    def test_or_bindings(self):
        """Each branch of an or() produces its own variable bindings."""

        c = self.parser.parse_line_clause_body(u'S is "a", or(str_append(S, "b"), str_append(S, "c"))')
        logging.debug(u'clause: %s' % c)
        sols = self.rt.search(c)
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 2)
        self.assertEqual(sols[0]['S'].s, "ab")
        self.assertEqual(sols[1]['S'].s, "ac")

        c = self.parser.parse_line_clause_body(u'X is 42; X is 23')
        logging.debug(u'clause: %s' % c)
        sols = self.rt.search(c)
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 2)

    def test_var_access(self):
        """Variables can be pre-bound from python and results read back."""

        # set var X from python:
        c = self.parser.parse_line_clause_body('Y is X*X')
        logging.debug('clause: %s' % c)
        sols = self.rt.search(c, {'X': NumberLiteral(3)})
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 1)

        # access prolog result Y from python:
        self.assertEqual(sols[0]['Y'].f, 9)

    def test_list_equality(self):
        """is/=/\\= behave correctly on (empty) lists and mixed operands."""

        c = self.parser.parse_line_clause_body('[] is []')
        logging.debug('clause: %s' % c)
        sols = self.rt.search(c, {})
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 1)

        c = self.parser.parse_line_clause_body('[1] is []')
        logging.debug('clause: %s' % c)
        sols = self.rt.search(c, {})
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 0)

        c = self.parser.parse_line_clause_body('909442800.0 is []')
        logging.debug('clause: %s' % c)
        sols = self.rt.search(c, {})
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 0)

        c = self.parser.parse_line_clause_body('[1,2,3] = [1,2,3]')
        logging.debug('clause: %s' % c)
        sols = self.rt.search(c, {})
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 1)

        c = self.parser.parse_line_clause_body('[1,2,3] \\= [1,2,3,4,5]')
        logging.debug('clause: %s' % c)
        sols = self.rt.search(c, {})
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 1)

    def test_is(self):
        """Re-binding an already bound variable to a different value fails."""

        c = self.parser.parse_line_clause_body('GENDER is "blubber", GENDER is wde:Male')
        logging.debug('clause: %s' % c)
        sols = self.rt.search(c, {})
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 0)

    # @unittest.skip("temporarily disabled")
    def test_list_eval(self):
        """Lists evaluate element-wise; unbound variables stay unbound."""

        c = self.parser.parse_line_clause_body('X is 23, Z is 42, Y is [X, U, Z].')
        sols = self.rt.search(c)
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 1)
        self.assertEqual(len(sols[0]['Y'].l), 3)
        self.assertEqual(sols[0]['Y'].l[0].f, 23.0)
        self.assertTrue(isinstance(sols[0]['Y'].l[1], Variable))
        self.assertEqual(sols[0]['Y'].l[2].f, 42.0)

    def test_clauses_location(self):
        """Runtime errors carry the source location of the offending term."""

        # this will trigger a runtime error since a(Y) is a predicate,
        # but format_str requires a literal arg
        c = self.parser.parse_line_clause_body('X is format_str("%s", a(Y))')
        logging.debug('clause: %s' % c)

        try:
            sols = self.rt.search(c, {})
            self.fail("we should have seen a runtime error here")
        except PrologRuntimeError as e:
            self.assertEqual(e.location.line, 1)
            self.assertEqual(e.location.col, 29)

    def test_cut(self):
        """cut prunes alternatives as expected (one/two/many/many)."""

        self.parser.compile_file('samples/cut_test.pl', UNITTEST_MODULE)

        # self.rt.set_trace(True)

        c = self.parser.parse_line_clause_body(u'bar(R, X)')
        logging.debug(u'clause: %s' % c)
        sols = self.rt.search(c)
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 4)
        self.assertEqual(sols[0]['R'].s, "one")
        self.assertEqual(sols[1]['R'].s, "two")
        self.assertEqual(sols[2]['R'].s, "many")
        self.assertEqual(sols[3]['R'].s, "many")

    # @unittest.skip("temporarily disabled")
    def test_anon_var(self):
        """Each anonymous _ is a distinct variable, so both bindings succeed."""

        c = self.parser.parse_line_clause_body('_ is 23, _ is 42.')
        sols = self.rt.search(c)
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 1)

    def test_nonvar(self):
        """nonvar/1 succeeds on bound and fails on unbound variables."""

        c = self.parser.parse_line_clause_body(u'S is "a", nonvar(S)')
        logging.debug(u'clause: %s' % c)
        sols = self.rt.search(c)
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 1)

        c = self.parser.parse_line_clause_body(u'nonvar(S)')
        logging.debug(u'clause: %s' % c)
        sols = self.rt.search(c)
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 0)

    def test_unify_pseudo(self):
        """Pseudo-variable (C:mem|bar) assignment and retrieval round-trips."""

        c = self.parser.parse_line_clause_body(u'C is foo, assertz(mem(foo, bar)), if var(C:mem|bar) then C:mem|bar := 23 endif, X := C:mem|bar')
        logging.debug(u'clause: %s' % c)

        # self.rt.set_trace(True)

        sols = self.rt.search(c)
        logging.debug('solutions: %s' % repr(sols))
        self.assertEqual(len(sols), 1)
        self.assertEqual(sols[0]['X'].f, 23.0)
class TestAIProlog (unittest.TestCase):
    """RDF-backed aiprolog tests: queries, filters, joins and assertions
    against a KB preloaded from n3 fixture files."""

    def setUp(self):

        config = misc.load_config('.airc')

        # logic DB on top of the configured model url
        self.db = LogicDB(model.url)

        # knowledge base seeded from fixture graphs
        self.kb = AIKB(UNITTEST_MODULE)
        for prefix in COMMON_PREFIXES:
            self.kb.register_prefix(prefix, COMMON_PREFIXES[prefix])

        self.kb.clear_all_graphs()
        self.kb.parse_file(UNITTEST_CONTEXT, 'n3', 'tests/chancellors.n3')
        self.kb.parse_file(UNITTEST_CONTEXT, 'n3', 'tests/wev.n3')

        # aiprolog runtime + parser
        self.prolog_rt = AIPrologRuntime(self.db, self.kb)
        self.parser    = AIPrologParser()

        self.prolog_rt.set_trace(True)

        self.db.clear_module(UNITTEST_MODULE)

    def _solve(self, src):
        """Parse a clause body, search it and return the solutions
        (with the same debug logging as inline queries)."""
        c = self.parser.parse_line_clause_body(src)
        logging.debug('clause: %s' % c)
        sols = self.prolog_rt.search(c)
        logging.debug('solutions: %s' % repr(sols))
        return sols

    # @unittest.skip("temporarily disabled")
    def test_rdf_results(self):
        """rdf-backed predicate enumerates both chancellors."""
        self.parser.compile_file('tests/chancellors_rdf.pl', UNITTEST_MODULE, self.db, self.kb)
        self.assertEqual(len(self._solve('chancellor(X)')), 2)

    # @unittest.skip("temporarily disabled")
    def test_rdf_exists(self):
        """A fully ground rdf/3 triple check succeeds exactly once."""
        sols = self._solve("rdf ('http://www.wikidata.org/entity/Q567', 'http://www.wikidata.org/prop/direct/P21', 'http://www.wikidata.org/entity/Q6581072')")
        self.assertEqual(len(sols), 1)

    # @unittest.skip("temporarily disabled")
    def test_rdf_optional(self):
        """OPTIONAL-style rdf lookup finds the single current chancellor."""
        self.parser.compile_file('tests/chancellors_rdf.pl', UNITTEST_MODULE, self.db, self.kb)
        self.assertEqual(len(self._solve("is_current_chancellor (X)")), 1)

    # @unittest.skip("temporarily disabled")
    def test_rdf_filter(self):
        """FILTERed rdf lookup yields one label per chancellor."""
        self.parser.compile_file('tests/chancellors_rdf.pl', UNITTEST_MODULE, self.db, self.kb)
        self.assertEqual(len(self._solve("chancellor_labels (X, Y)")), 2)

    # @unittest.skip("temporarily disabled")
    def test_rdf_filter_expr(self):
        """and/or and implicit-and filter expressions restrict results."""

        # and(): both bounds -> exactly the one matching termEnd
        sols = self._solve('rdf (X, dbp:termEnd, TE, filter(and(TE =< "1998-10-27", TE >= "1998-10-27")))')
        self.assertEqual(len(sols), 1)

        # or(): either bound -> both termEnds match
        sols = self._solve('rdf (X, dbp:termEnd, TE, filter(or(TE =< "1998-10-27", TE >= "1998-10-27")))')
        self.assertEqual(len(sols), 2)

        # multiple filter args are an implicit conjunction
        sols = self._solve('rdf (X, dbp:termEnd, TE, filter(TE =< "1998-10-27", TE =< "1998-10-27", TE >= "1998-10-27"))')
        self.assertEqual(len(sols), 1)

    # @unittest.skip("temporarily disabled")
    def test_rdf_joins(self):
        """A multi-triple distinct join with date + language filters."""

        sols = self._solve("""
               uriref(wde:Q61656, P),
               Lang is de,
               atom_chars(Lang, L2),
               date_time_stamp(date(2016,12,6,0,0,0,\'local\'), EvTS),
               date_time_stamp(date(2016,12,7,0,0,0,\'local\'), EvTE),
               rdf (distinct,
                    WEV, ai:dt_end,          DT_END,
                    WEV, ai:dt_start,        DT_START,
                    WEV, ai:location,        P,
                    P,   rdfs:label,         Label,
                    WEV, ai:temp_min,        TempMin,
                    WEV, ai:temp_max,        TempMax,
                    WEV, ai:precipitation,   Precipitation,
                    WEV, ai:clouds,          Clouds,
                    WEV, ai:icon,            Icon,
                    filter (DT_START >= isoformat(EvTS, 'local'),
                            DT_END =< isoformat(EvTE, 'local'),
                            lang(Label) = L2)
                    )
               """)
        self.assertEqual(len(sols), 7)

    # @unittest.skip("temporarily disabled")
    def test_rdf_assert(self):
        """rdf_assert queues an action which, once executed, adds a triple."""

        # nothing known about Alice yet
        c = self.parser.parse_line_clause_body('rdf(aiu:Alice, X, Y).')
        sols = self.prolog_rt.search(c)
        self.assertEqual(len(sols), 0)

        # assert a name triple; this produces exactly one action buffer
        sols = self._solve('rdf_assert (aiu:Alice, aiup:name, "Alice Green"), eoa.')

        acts = self.prolog_rt.get_actions()
        logging.debug('actions: %s' % repr(acts))
        self.assertEqual(len(acts), 1)

        self.prolog_rt.execute_builtin_actions(acts[0])

        # now the triple is visible
        c = self.parser.parse_line_clause_body('rdf(aiu:Alice, X, Y).')
        sols = self.prolog_rt.search(c)
        self.assertEqual(len(sols), 1)
        self.assertEqual(sols[0]['X'].s, u'http://ai.zamia.org/kb/user/prop/name')
        self.assertEqual(sols[0]['Y'].s, u'Alice Green')

    # @unittest.skip("temporarily disabled")
    def test_rdf_assert_list(self):
        """Asserted list objects round-trip with element types preserved."""

        sols = self._solve('rdf_assert (aiu:Alice, aiup:topic, [1, "abc", wde:42]), eoa.')

        acts = self.prolog_rt.get_actions()
        logging.debug('actions: %s' % repr(acts))
        self.assertEqual(len(acts), 1)

        self.prolog_rt.execute_builtin_actions(acts[0])

        sols = self._solve('rdf(aiu:Alice, aiup:topic, Y).')
        self.assertEqual(len(sols), 1)
        self.assertEqual(len(sols[0]['Y'].l), 3)
        self.assertEqual(sols[0]['Y'].l[0].f, 1.0)
        self.assertEqual(sols[0]['Y'].l[1].s, u'abc')
        self.assertEqual(sols[0]['Y'].l[2].name, u'wde:42')
class TestBuiltins(unittest.TestCase):
    """Exercise the builtin predicates and functions of the plain PrologRuntime."""

    def setUp(self):

        #
        # db, store
        #

        db_url = 'sqlite:///foo.db'

        # setup compiler + environment

        self.db = LogicDB(db_url)
        self.parser = PrologParser(self.db)
        self.rt = PrologRuntime(self.db)

        self.db.clear_module(UNITTEST_MODULE)

    def tearDown(self):
        self.db.close()

    # @unittest.skip("temporarily disabled")
    def test_hanoi1(self):

        self.parser.compile_file('samples/hanoi1.pl', UNITTEST_MODULE)

        clause = self.parser.parse_line_clause_body('move(3,left,right,center)')
        logging.debug('clause: %s' % clause)
        solutions = self.rt.search(clause)
        logging.debug('solutions: %s' % repr(solutions))
        self.assertEqual(len(solutions), 1)

    # @unittest.skip("temporarily disabled")
    def test_lists(self):
        """List literals, aggregate functions and list manipulation builtins."""

        # empty list literal
        clause = self.parser.parse_line_clause_body('X is []')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions[0]['X'].l), 0)

        # aggregate functions over a list
        clause = self.parser.parse_line_clause_body('L is [1,2,3,4], X is list_sum(L), Y is list_max(L), Z is list_min(L), W is list_avg(L), V is list_len(L)')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions[0]['L'].l), 4)
        self.assertEqual(solutions[0]['L'].l[3].f, 4.0)
        self.assertEqual(solutions[0]['X'].f, 10.0)
        self.assertEqual(solutions[0]['Y'].f, 4.0)
        self.assertEqual(solutions[0]['Z'].f, 1.0)
        self.assertEqual(solutions[0]['W'].f, 2.5)
        self.assertEqual(solutions[0]['V'].f, 4.0)

        # membership: succeeds / fails
        clause = self.parser.parse_line_clause_body('L is [1,2,3,4], list_contains(L, 2).')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 1)

        clause = self.parser.parse_line_clause_body('L is [1,2,3,4], list_contains(L, 23).')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 0)

        # indexed access and length
        clause = self.parser.parse_line_clause_body('X is [1,2,3,4], list_nth(1, X, E).')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['E'].f, 2)

        clause = self.parser.parse_line_clause_body('X is [1,2,3,4], length(X, L).')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['L'].f, 4)

        # slicing, both as a predicate and as an eval function
        clause = self.parser.parse_line_clause_body('X is [1,2,3,4], list_slice(1, 3, X, E).')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions[0]['E'].l), 2)
        self.assertEqual(solutions[0]['E'].l[0].f, 2.0)
        self.assertEqual(solutions[0]['E'].l[1].f, 3.0)

        clause = self.parser.parse_line_clause_body('X is [1,2,3,4], E is list_slice(1, 3, X).')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions[0]['E'].l), 2)
        self.assertEqual(solutions[0]['E'].l[0].f, 2.0)
        self.assertEqual(solutions[0]['E'].l[1].f, 3.0)

        # in-place append
        clause = self.parser.parse_line_clause_body('X is [1,2,3,4], list_append(X, 5).')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions[0]['X'].l), 5)
        self.assertEqual(solutions[0]['X'].l[4].f, 5.0)

        # string join, both as a predicate and as an eval function
        clause = self.parser.parse_line_clause_body('X is ["1","2","3","4"], list_str_join("@", X, Y).')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['Y'].s, "1@2@3@4")

        clause = self.parser.parse_line_clause_body('X is ["1","2","3","4"], Y is list_join("@", X).')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['Y'].s, "1@2@3@4")

    # @unittest.skip("temporarily disabled")
    def test_list_findall(self):

        self.parser.compile_file('samples/kb1.pl', UNITTEST_MODULE)

        clause = self.parser.parse_line_clause_body('list_findall(X, woman(X), L)')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions[0]), 1)
        self.assertEqual(len(solutions[0]['L'].l), 3)

    # @unittest.skip("temporarily disabled")
    def test_strings(self):
        """format_str, sub_string and atom_chars builtins."""

        clause = self.parser.parse_line_clause_body('X is \'bar\', S is format_str(\'test %d %s foo\', 42, X)')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['S'].s, 'test 42 bar foo')

        clause = self.parser.parse_line_clause_body('X is \'foobar\', sub_string(X, 0, 2, _, Y)')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['Y'].s, 'fo')

        # atom_chars works in both directions: atom->string and string->atom
        clause = self.parser.parse_line_clause_body('atom_chars(foo, X), atom_chars(Y, "bar").')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['X'].s, 'foo')
        self.assertEqual(solutions[0]['Y'].name, 'bar')

    # @unittest.skip("temporarily disabled")
    def test_date_time(self):
        """get_time, date_time_stamp, stamp_date_time and day_of_the_week."""

        clause = self.parser.parse_line_clause_body('get_time(T)')
        solutions = self.rt.search(clause)
        # current time must be later than the date this test was written
        self.assertGreater(solutions[0]['T'].s, '2017-04-30T23:39:29.092271')

        clause = self.parser.parse_line_clause_body('date_time_stamp(date(2017,2,14,1,2,3,\'local\'), TS), stamp_date_time(TS, date(Y,M,D,H,Mn,S,\'local\'))')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['Y'].f, 2017)
        self.assertEqual(solutions[0]['M'].f, 2)
        self.assertEqual(solutions[0]['D'].f, 14)
        self.assertEqual(solutions[0]['H'].f, 1)
        self.assertEqual(solutions[0]['Mn'].f, 2)
        self.assertEqual(solutions[0]['S'].f, 3)

        clause = self.parser.parse_line_clause_body('date_time_stamp(date(2017,2,14,1,2,3,\'Europe/Berlin\'), TS), day_of_the_week(TS, WD)')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['TS'].s, '2017-02-14T01:02:03+01:00')
        self.assertEqual(solutions[0]['WD'].f, 2)

    # @unittest.skip("temporarily disabled")
    def test_arith(self):
        """Unary sign, +, -, *, /, mod and the increment/decrement builtins."""

        clause = self.parser.parse_line_clause_body('X is -23')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['X'].f, -23)

        clause = self.parser.parse_line_clause_body('X is +42')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['X'].f, 42)

        clause = self.parser.parse_line_clause_body('X is 19 + 23')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['X'].f, 42)

        clause = self.parser.parse_line_clause_body('X is 61 - 19')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['X'].f, 42)

        clause = self.parser.parse_line_clause_body('X is 6*7')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['X'].f, 42)

        clause = self.parser.parse_line_clause_body('X is 1764 / 42')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['X'].f, 42)

        clause = self.parser.parse_line_clause_body('X is 85 mod 43')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['X'].f, 42)

        clause = self.parser.parse_line_clause_body('X is 23, increment(X, 19)')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['X'].f, 42)

        clause = self.parser.parse_line_clause_body('X is 42, decrement(X, 19)')
        solutions = self.rt.search(clause)
        self.assertEqual(solutions[0]['X'].f, 23)

    # @unittest.skip("temporarily disabled")
    def test_comp(self):
        """Comparison operators >, <, =<, >=, \\= and =."""

        clause = self.parser.parse_line_clause_body('3>1')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 1)

        clause = self.parser.parse_line_clause_body('1>1')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 0)

        clause = self.parser.parse_line_clause_body('3<1')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 0)

        clause = self.parser.parse_line_clause_body('1<1')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 0)

        clause = self.parser.parse_line_clause_body('3=<1')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 0)

        clause = self.parser.parse_line_clause_body('1=<1')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 1)

        clause = self.parser.parse_line_clause_body('3>=1')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 1)

        clause = self.parser.parse_line_clause_body('1>=1')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 1)

        clause = self.parser.parse_line_clause_body('3\\=1')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 1)

        clause = self.parser.parse_line_clause_body('1\\=1')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 0)

        clause = self.parser.parse_line_clause_body('3=1')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 0)

        clause = self.parser.parse_line_clause_body('1=1')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 1)

    # @unittest.skip("temporarily disabled")
    def test_between(self):

        clause = self.parser.parse_line_clause_body('between(1,100,42)')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 1)

        # unbound third argument enumerates the whole range
        clause = self.parser.parse_line_clause_body('between(1,100,X)')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 100)

    # @unittest.skip("temporarily disabled")
    def test_dicts(self):
        """dict_put / dict_get builtins; dict_get with unbound key enumerates entries."""

        clause = self.parser.parse_line_clause_body('dict_put(U, foo, 42), X is U, dict_put(X, bar, 23), dict_get(X, Y, Z), dict_get(X, foo, V)')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions[0]['U'].d), 1)
        self.assertEqual(len(solutions[0]['X'].d), 2)
        self.assertEqual(solutions[0]['Z'].f, 42)
        self.assertEqual(solutions[0]['V'].f, 42)
        self.assertEqual(solutions[1]['Z'].f, 23)
        self.assertEqual(solutions[1]['V'].f, 42)
        logging.debug(repr(solutions))

    # @unittest.skip("temporarily disabled")
    def test_assertz(self):

        clause = self.parser.parse_line_clause_body('I is ias00001, assertz(frame (I, qIsFamiliar)), frame (ias00001, X)')
        solutions = self.rt.search(clause)
        logging.debug(repr(solutions))
        self.assertEqual(len(solutions), 1)
        self.assertEqual(solutions[0]['X'].name, 'qIsFamiliar')

    # @unittest.skip("temporarily disabled")
    def test_retract(self):

        clause = self.parser.parse_line_clause_body('I is ias1, assertz(frame (I, a, x)), retract(frame (I, _, _)), assertz(frame (I, a, y)), frame(ias1, a, X)')
        solutions = self.rt.search(clause)
        logging.debug(repr(solutions))
        self.assertEqual(len(solutions), 1)
        self.assertEqual(solutions[0]['X'].name, 'y')

    # @unittest.skip("temporarily disabled")
    def test_retract_db(self):
        """assertz/retract only take effect in the DB once the overlay is applied."""

        clause = self.parser.parse_line_clause_body('I is ias1, assertz(frame (I, a, x))')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions), 1)

        # not applied yet -> not visible in a fresh search
        clause = self.parser.parse_line_clause_body('frame(ias1, a, X)')
        s2s = self.rt.search(clause)
        self.assertEqual(len(s2s), 0)

        self.rt.apply_overlay(UNITTEST_MODULE, solutions[0])

        clause = self.parser.parse_line_clause_body('frame(ias1, a, X)')
        s2s = self.rt.search(clause)
        self.assertEqual(len(s2s), 1)
        self.assertEqual(s2s[0]['X'].name, 'x')

        # retract within a search is local to that search ...
        clause = self.parser.parse_line_clause_body('retract(frame (ias1, _, _)), frame(ias1, a, X)')
        s2s = self.rt.search(clause)
        self.assertEqual(len(s2s), 0)

        # ... until its overlay is applied to the DB
        clause = self.parser.parse_line_clause_body('retract(frame (ias1, _, _))')
        solutions = self.rt.search(clause)

        self.rt.apply_overlay(UNITTEST_MODULE, solutions[0])

        clause = self.parser.parse_line_clause_body('frame(ias1, a, X)')
        s2s = self.rt.search(clause)
        self.assertEqual(len(s2s), 0)

    # @unittest.skip("temporarily disabled")
    def test_setz(self):

        clause = self.parser.parse_line_clause_body('assertz(frame (ias1, a, x)), assertz(frame (ias1, a, y)), setz(frame (ias1, a, _), z), frame (ias1, a, X)')
        solutions = self.rt.search(clause)
        logging.debug(repr(solutions))
        self.assertEqual(len(solutions), 1)
        self.assertEqual(solutions[0]['X'].name, 'z')

    # @unittest.skip("temporarily disabled")
    def test_setz_multi(self):

        # self.rt.set_trace(True)

        clause = self.parser.parse_line_clause_body('I is ias2, setz(ias (I, a, _), a), setz(ias (I, b, _), b), setz(ias (I, c, _), c), ias(I, X, Y). ')
        solutions = self.rt.search(clause)
        logging.debug(repr(solutions))
        self.assertEqual(len(solutions), 3)
        self.assertEqual(solutions[0]['X'].name, 'a')
        self.assertEqual(solutions[0]['Y'].name, 'a')
        self.assertEqual(solutions[1]['X'].name, 'b')
        self.assertEqual(solutions[1]['Y'].name, 'b')
        self.assertEqual(solutions[2]['X'].name, 'c')
        self.assertEqual(solutions[2]['Y'].name, 'c')

    # @unittest.skip("temporarily disabled")
    def test_gensym(self):

        logging.debug('test_gensym...')
        clause = self.parser.parse_line_clause_body('gensym(foo, I), gensym(foo, J)')
        solutions = self.rt.search(clause)
        logging.debug(repr(solutions))
        self.assertEqual(len(solutions), 1)
        # two gensym calls must yield distinct symbols
        self.assertNotEqual(solutions[0]['I'].name, solutions[0]['J'].name)
        logging.debug('test_gensym... done.')

    # @unittest.skip("temporarily disabled")
    def test_sets(self):

        # adding the same element twice leaves a 2-element set
        clause = self.parser.parse_line_clause_body('set_add(S, 42), set_add(S, 23), set_add(S, 23), set_get(S, V)')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions[0]), 2)
        self.assertEqual(len(solutions[0]['S'].s), 2)
        logging.debug(repr(solutions))

    # @unittest.skip("temporarily disabled")
    def test_set_findall(self):

        self.parser.compile_file('samples/kb1.pl', UNITTEST_MODULE)

        clause = self.parser.parse_line_clause_body('set_findall(X, woman(X), S)')
        solutions = self.rt.search(clause)
        self.assertEqual(len(solutions[0]), 1)
        self.assertEqual(len(solutions[0]['S'].s), 3)

    # @unittest.skip("temporarily disabled")
    def test_eval_functions(self):

        # NOTE(review): Z is unbound here -- presumably this checks that eval
        # functions inside a list literal still produce a 2-element list; confirm
        # against the runtime's eval semantics.
        clause = self.parser.parse_line_clause_body('X is [23, 42], Y is [list_avg(X), list_sum(Z)]')
        solutions = self.rt.search(clause)
        logging.debug(repr(solutions))
        self.assertEqual(len(solutions), 1)
        self.assertEqual(len(solutions[0]['Y'].l), 2)

    # @unittest.skip("temporarily disabled")
    def test_set(self):

        # a second set/2 on the same variable overwrites the first
        clause = self.parser.parse_line_clause_body('set(X, 23), set(X, 42), set(Y, 23), Z := Y * 2')
        solutions = self.rt.search(clause)
        logging.debug(repr(solutions))
        self.assertEqual(len(solutions), 1)
        self.assertEqual(solutions[0]['X'].f, 42)
        self.assertEqual(solutions[0]['Y'].f, 23)
        self.assertEqual(solutions[0]['Z'].f, 46)

    # @unittest.skip("temporarily disabled")
    def test_set_pseudo(self):

        clause = self.parser.parse_line_clause_body('assertz(foo(bar, 23)), set(bar:foo, 42), foo(bar, X), Z := bar:foo')
        # self.rt.set_trace(True)
        solutions = self.rt.search(clause)
        logging.debug(repr(solutions))
        self.assertEqual(len(solutions), 1)
        self.assertEqual(solutions[0]['X'].f, 42)
class TestNegation(unittest.TestCase):
    """Tests for negation-as-failure (not/1) in the PrologRuntime."""

    def setUp(self):

        #
        # db, store
        #

        db_url = 'sqlite:///foo.db'

        # setup compiler + environment

        self.db = LogicDB(db_url)
        self.parser = PrologParser(self.db)
        self.rt = PrologRuntime(self.db)

        self.db.clear_module(UNITTEST_MODULE)
        self.rt.set_trace(True)

    def tearDown(self):
        self.db.close()

    # @unittest.skip("temporarily disabled")
    def test_not_succ(self):

        # not(1 is 2) succeeds
        clause = self.parser.parse_line_clause_body('X is 1, Y is 2, not(X is Y).')
        logging.debug('clause: %s' % clause)
        solutions = self.rt.search(clause, {})
        logging.debug('solutions: %s' % repr(solutions))
        self.assertEqual(len(solutions), 1)

    # @unittest.skip("temporarily disabled")
    def test_not_fail(self):

        # not(2 is 2) fails
        clause = self.parser.parse_line_clause_body('X is 2, Y is 2, not(X is Y).')
        logging.debug('clause: %s' % clause)
        solutions = self.rt.search(clause, {})
        logging.debug('solutions: %s' % repr(solutions))
        self.assertEqual(len(solutions), 0)

    # @unittest.skip("temporarily disabled")
    def test_chancellors(self):

        self.parser.compile_file('samples/not_test.pl', UNITTEST_MODULE)

        clause = self.parser.parse_line_clause_body('was_chancellor(helmut_kohl).')
        logging.debug('clause: %s' % clause)
        solutions = self.rt.search(clause, {})
        logging.debug('solutions: %s' % repr(solutions))
        self.assertEqual(len(solutions), 1)

    # @unittest.skip("temporarily disabled")
    def test_double_negation(self):

        self.parser.compile_file('samples/not_test.pl', UNITTEST_MODULE)

        clause = self.parser.parse_line_clause_body('not(not(chancellor(helmut_kohl))).')
        logging.debug('clause: %s' % clause)
        solutions = self.rt.search(clause, {})
        logging.debug('solutions: %s' % repr(solutions))
        self.assertEqual(len(solutions), 1)

        clause = self.parser.parse_line_clause_body('not(not(chancellor(angela_merkel))).')
        logging.debug('clause: %s' % clause)
        solutions = self.rt.search(clause, {})
        logging.debug('solutions: %s' % repr(solutions))
        self.assertEqual(len(solutions), 1)

        # double negation with an unbound variable enumerates both chancellors
        clause = self.parser.parse_line_clause_body('not(not(chancellor(X))).')
        logging.debug('clause: %s' % clause)
        solutions = self.rt.search(clause, {})
        logging.debug('solutions: %s' % repr(solutions))
        self.assertEqual(len(solutions), 2)

    # @unittest.skip("temporarily disabled")
    def test_assertz_negation(self):

        # not/1 must see facts asserted earlier in the same search
        clause = self.parser.parse_line_clause_body('assertz(foobar(a)), foobar(a), (not(foobar(b))).')
        logging.debug('clause: %s' % clause)
        solutions = self.rt.search(clause, {})
        logging.debug('solutions: %s' % repr(solutions))
        self.assertEqual(len(solutions), 1)