def prepare_ruler(args):
    """Build a ruler from mined rules and persist it as a pickle.

    Steps, in order:

    1. Check the input rules TSV, the input split directory and the output
       ruler PKL.
    2. Load the rules, keep those with ``conf >= min_conf`` and
       ``fires >= min_supp``, sort them by descending confidence and keep
       only rules whose body consists of a single fact.
    3. For each such rule, query the graph database for the entities that
       match the rule body and instantiate the rule head with them.
    4. Store, per predicted fact, the rules that predicted it in
       ``pred[head][(rel, tail)]`` and save the resulting ``Ruler``.

    Args:
        args: Namespace providing ``rules_tsv``, ``url``, ``username``,
            ``password``, ``split_dir``, ``ruler_pkl``, ``min_conf``,
            ``min_supp`` and ``overwrite``.
    """
    rules_tsv_path = args.rules_tsv

    url = args.url
    username = args.username
    password = args.password

    split_dir_path = args.split_dir
    ruler_pkl_path = args.ruler_pkl

    min_conf = args.min_conf
    min_supp = args.min_supp
    overwrite = args.overwrite

    #
    # Check that (input) POWER Rules TSV exists
    #

    logging.info('Check that (input) POWER Rules TSV exists ...')

    rules_tsv = RulesTsv(Path(rules_tsv_path))
    rules_tsv.check()

    #
    # Check that (input) POWER Split Directory exists
    #

    logging.info('Check that (input) POWER Split Directory exists ...')

    split_dir = SplitDir(Path(split_dir_path))
    split_dir.check()

    #
    # Check that (output) POWER Ruler PKL does not exist
    #

    logging.info('Check that (output) POWER Ruler PKL does not exist ...')

    ruler_pkl = RulerPkl(Path(ruler_pkl_path))
    # NOTE(review): presumably tolerates an existing file only when
    # overwriting was requested - confirm against RulerPkl.check().
    ruler_pkl.check(should_exist=overwrite)

    #
    # Read rules
    #

    logging.info('Read rules ...')

    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    anyburl_rules = rules_tsv.load()
    rules = [Rule.from_anyburl(rule, ent_to_lbl, rel_to_lbl) for rule in anyburl_rules]

    good_rules = [rule for rule in rules if rule.conf >= min_conf and rule.fires >= min_supp]
    good_rules.sort(key=lambda rule: rule.conf, reverse=True)

    # Only rules with a single body fact are processed below
    short_rules = [rule for rule in good_rules if len(rule.body) == 1]
    log_rules('Rules', short_rules)

    #
    # Load train facts
    #

    logging.info('Load train facts ...')

    train_triples = split_dir.train_facts_tsv.load()
    # NOTE(review): currently unused because the train-fact filter in the
    # processing loop is disabled; kept so the filter can be re-enabled.
    train_facts = {Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
                   for head, _, rel, _, tail, _ in train_triples}

    #
    # Process rules
    #

    logging.info('Process rules ...')

    unsupported_rules = 0
    # NOTE(review): get_defaultdict is presumably a module-level factory
    # (instead of a lambda) so that the nested mapping stays picklable.
    pred = defaultdict(get_defaultdict)

    driver = GraphDatabase.driver(url, auth=(username, password))
    try:
        with driver.session() as session:

            # A progress bar would interleave with per-rule debug output
            if logging.getLogger().level == logging.DEBUG:
                iter_short_rules = short_rules
            else:
                iter_short_rules = tqdm(short_rules)

            for rule in iter_short_rules:
                logging.debug('Process rule %s', rule)

                #
                # Process rule body
                #

                ents = _query_body_ents(session, rule.body[0], ent_to_lbl)
                if ents is None:
                    logging.debug('Unsupported rule body in rule %s. Skipping.', rule)
                    unsupported_rules += 1
                    continue

                #
                # Process rule head
                #

                pred_facts = _instantiate_head_facts(rule.head, ents)
                if pred_facts is None:
                    logging.debug('Unsupported rule head in rule %s. Skipping.', rule)
                    unsupported_rules += 1
                    continue

                #
                # Save predicted facts (filtering out train facts is
                # currently disabled, i.e. train facts are saved as well)
                #

                for fact in pred_facts:
                    pred[fact.head][(fact.rel, fact.tail)].append(rule)
    finally:
        # Release the driver's connections even if processing fails
        driver.close()

    logging.info('Skipped %d unsupported rules', unsupported_rules)

    #
    # Persist ruler
    #

    logging.info('Persist ruler ...')

    ruler = Ruler()
    ruler.pred = pred

    ruler_pkl.save(ruler)


def _query_body_ents(session, body_fact, ent_to_lbl):
    """Return the entities matching a single-fact rule body, or None if unsupported.

    Supported bodies have exactly one variable side: either ``(Var, rel, Ent)``
    or ``(Ent, rel, Var)``. The entities found on the variable side are returned.
    """
    # NOTE(review): these look like read-only queries; if so,
    # session.read_transaction would be the better fit - confirm against
    # the query_facts_by_* helpers before changing.
    if isinstance(body_fact.head, Var) and isinstance(body_fact.tail, Ent):
        records = session.write_transaction(query_facts_by_rel_tail,
                                            rel=body_fact.rel, tail=body_fact.tail)
        return [Ent(head['id'], ent_to_lbl[head['id']]) for head, _, _ in records]

    if isinstance(body_fact.head, Ent) and isinstance(body_fact.tail, Var):
        records = session.write_transaction(query_facts_by_head_rel,
                                            head=body_fact.head, rel=body_fact.rel)
        return [Ent(tail['id'], ent_to_lbl[tail['id']]) for _, _, tail in records]

    return None


def _instantiate_head_facts(head_fact, ents):
    """Return the head fact instantiated with each entity, or None if unsupported.

    Each entity replaces the variable side of a ``(Var, rel, Ent)`` or
    ``(Ent, rel, Var)`` rule head.
    """
    if isinstance(head_fact.head, Var) and isinstance(head_fact.tail, Ent):
        return [Fact(ent, head_fact.rel, head_fact.tail) for ent in ents]

    if isinstance(head_fact.head, Ent) and isinstance(head_fact.tail, Var):
        return [Fact(head_fact.head, head_fact.rel, ent) for ent in ents]

    return None