コード例 #1
0
def prepare_ruler(args):
    """Build a ``Ruler`` from mined rules and persist it as a pickle.

    Steps, in order:

    1. Check the input Rules TSV and Split Directory, and the output Ruler PKL.
    2. Load the rules, convert them with ``Rule.from_anyburl`` and keep those
       with ``conf >= min_conf`` and ``fires >= min_supp``, sorted by
       confidence (highest first). Only rules whose body has exactly one
       fact are processed further.
    3. For every such rule, query the graph database for the entities that
       match the rule body and derive the facts predicted by the rule head.
       Rules whose body or head is not a (Var, Ent) / (Ent, Var) combination
       are skipped and counted as unsupported.
    4. Store the predictions as ``pred[head][(rel, tail)] -> rules`` on a new
       ``Ruler`` and save it via ``ruler_pkl.save``.

    :param args: Namespace-like object providing the attributes
        ``rules_tsv``, ``url``, ``username``, ``password``, ``split_dir``,
        ``ruler_pkl``, ``min_conf``, ``min_supp`` and ``overwrite``.
    """

    rules_tsv_path = args.rules_tsv
    url = args.url
    username = args.username
    password = args.password
    split_dir_path = args.split_dir
    ruler_pkl_path = args.ruler_pkl

    min_conf = args.min_conf
    min_supp = args.min_supp
    overwrite = args.overwrite

    #
    # Check that (input) POWER Rules TSV exists
    #

    logging.info('Check that (input) POWER Rules TSV exists ...')

    rules_tsv = RulesTsv(Path(rules_tsv_path))
    rules_tsv.check()

    #
    # Check that (input) POWER Split Directory exists
    #

    logging.info('Check that (input) POWER Split Directory exists ...')

    split_dir = SplitDir(Path(split_dir_path))
    split_dir.check()

    #
    # Check that (output) POWER Ruler PKL does not exist
    #

    logging.info('Check that (output) POWER Ruler PKL does not exist ...')

    # NOTE(review): presumably an existing PKL is only accepted when
    # overwriting was requested -- confirm against RulerPkl.check.
    ruler_pkl = RulerPkl(Path(ruler_pkl_path))
    ruler_pkl.check(should_exist=overwrite)

    #
    # Read rules
    #

    logging.info('Read rules ...')

    ent_to_lbl = split_dir.entities_tsv.load()
    rel_to_lbl = split_dir.relations_tsv.load()

    anyburl_rules = rules_tsv.load()
    rules = [Rule.from_anyburl(rule, ent_to_lbl, rel_to_lbl) for rule in anyburl_rules]

    good_rules = [rule for rule in rules if rule.conf >= min_conf and rule.fires >= min_supp]
    good_rules.sort(key=lambda rule: rule.conf, reverse=True)

    # Only rules with a single body fact are supported by the processing below
    short_rules = [rule for rule in good_rules if len(rule.body) == 1]
    log_rules('Rules', short_rules)

    #
    # Load train facts
    #

    logging.info('Load train facts ...')

    # NOTE: train_facts is currently not consulted, because the train-fact
    # filter in the rule loop below is disabled.
    train_triples = split_dir.train_facts_tsv.load()
    train_facts = {Fact.from_ints(head, rel, tail, ent_to_lbl, rel_to_lbl)
                   for head, _, rel, _, tail, _ in train_triples}

    #
    # Process rules
    #

    logging.info('Process rules ...')

    driver = GraphDatabase.driver(url, auth=(username, password))
    unsupported_rules = 0

    pred = defaultdict(get_defaultdict)

    # Close the driver even if a query or rule conversion raises
    try:
        with driver.session() as session:
            # Show a progress bar unless debug logging is on, where the
            # per-rule debug messages would be interleaved with the bar
            if logging.getLogger().level == logging.DEBUG:
                iter_short_rules = short_rules
            else:
                iter_short_rules = tqdm(short_rules)

            for rule in iter_short_rules:
                logging.debug(f'Process rule {rule}')

                #
                # Process rule body
                #

                body_fact = rule.body[0]

                if type(body_fact.head) == Var and type(body_fact.tail) == Ent:
                    # Body (X, rel, ent): collect the heads of matching facts
                    records = session.write_transaction(query_facts_by_rel_tail,
                                                        rel=body_fact.rel, tail=body_fact.tail)
                    ents = [Ent(head['id'], ent_to_lbl[head['id']]) for head, _, _ in records]

                elif type(body_fact.head) == Ent and type(body_fact.tail) == Var:
                    # Body (ent, rel, X): collect the tails of matching facts
                    records = session.write_transaction(query_facts_by_head_rel,
                                                        head=body_fact.head, rel=body_fact.rel)
                    ents = [Ent(tail['id'], ent_to_lbl[tail['id']]) for _, _, tail in records]

                else:
                    logging.debug(f'Unsupported rule body in rule {rule}. Skipping.')
                    unsupported_rules += 1
                    continue

                #
                # Process rule head
                #

                head_fact = rule.head

                if type(head_fact.head) == Var and type(head_fact.tail) == Ent:
                    pred_facts = [Fact(ent, head_fact.rel, head_fact.tail) for ent in ents]

                elif type(head_fact.head) == Ent and type(head_fact.tail) == Var:
                    pred_facts = [Fact(head_fact.head, head_fact.rel, ent) for ent in ents]

                else:
                    logging.debug(f'Unsupported rule head in rule {rule}. Skipping.')
                    unsupported_rules += 1
                    continue

                #
                # Save predicted facts (train facts are currently NOT filtered out)
                #

                for fact in pred_facts:
                    pred[fact.head][(fact.rel, fact.tail)].append(rule)
    finally:
        driver.close()

    logging.info(f'Skipped {unsupported_rules} unsupported rules')

    #
    # Persist ruler
    #

    logging.info('Persist ruler ...')

    ruler = Ruler()
    ruler.pred = pred

    ruler_pkl.save(ruler)