Example #1
0
class SampleKB:
    """Generate a random synthetic knowledge base together with logical rules.

    Relations of arity 1, 2 and 3 are named ``u<i>``, ``b<i>`` and ``t<i>``
    (optionally prefixed with ``REL$``). Random ground facts are added to
    ``self.kb`` via ``add_train``. Randomly chosen rule instances are
    registered via ``add_formulae`` under the keys ``"inv"``, ``"impl"``,
    ``"impl_conj"`` and ``"trans"``.
    """

    def __init__(self, num_relations, num_entities,
                 arities=(0.0, 1.0, 0.0),
                 fb_densities=(0.0, 0.0, 0.0),
                 arg_densities=(0., 0.1, 0.0),
                 fact_prob=0.2,
                 num_symm=2,
                 num_impl=(0, 2, 0),
                 num_impl_inv=2,
                 num_impl_conj=(0, 2, 0),
                 num_trans_single=2,
                 num_trans_diff=2,
                 seed=0,
                 position_dependent_args=False,
                 position_densities=(0., 0.5, 0.0)):
        """
        :param num_relations: total number of relations. It is split across
            arities 1..3 as ``int(fraction * num_relations)`` per arity.
        :param num_entities: number of distinct entities to generate, named
            ``e1`` .. ``e<num_entities>``.
        :param arities: fraction of ``num_relations`` assigned to arity 1, 2
            and 3 respectively.
        :param fb_densities: per-arity probability that a relation name gets
            the ``REL$`` prefix.
        :param arg_densities: per-arity fraction of entity combinations that
            are observed. Index 1 applies to pairs and index 2 to triples;
            index 0 is unused.
        :param fact_prob: fraction of the observed argument tuples that become
            a training fact for each relation.
        :param num_symm: number of symmetry rules r(X,Y) => r(Y,X).
        :param num_impl: per-arity number of implications r1 => r2.
        :param num_impl_inv: number of inverse implications
            r1(X,Y) => r2(Y,X).
        :param num_impl_conj: per-arity number of conjunctive implications
            r1 ^ r2 => r3. Only arities with at least 3 relations are used.
        :param num_trans_single: number of transitivities
            r(X,Y) ^ r(Y,Z) => r(X,Z).
        :param num_trans_diff: number of transitivities
            r1(X,Y) ^ r2(Y,Z) => r3(X,Z) with r1, r2, r3 pairwise distinct.
        :param seed: seed for the global ``random`` module, also passed to
            ``KB``.
        :param position_dependent_args: if True, split the entities into
            disjoint pools for the first, second and third argument position.
        :param position_densities: fractions of the entities reserved for
            argument positions 1 and 2 when ``position_dependent_args`` is
            True. The remainder goes to position 3, and index 2 is unused.

        A requested rule count larger than the number of available candidate
        rules is capped at that number instead of raising ``ValueError``.
        """
        # Seeds the *global* random module; all sampling below uses it.
        random.seed(seed)
        self.kb = KB(seed=seed)

        num_relations_per_arity = [int(x * num_relations) for x in arities]

        entities = ["e" + str(i) for i in range(1, num_entities + 1)]

        if position_dependent_args:
            # Disjoint entity pools per argument position.
            arg1_boundary = int(len(entities) * position_densities[0])
            arg2_boundary = arg1_boundary + int(len(entities) * position_densities[1])
            entities_arg1 = entities[0:arg1_boundary]
            entities_arg2 = entities[arg1_boundary:arg2_boundary]
            entities_arg3 = entities[arg2_boundary:]
        else:
            entities_arg1 = entities
            entities_arg2 = entities
            entities_arg3 = entities

        pairs = [(x, y) for x in entities_arg1
                 for y in entities_arg2 if x != y]

        triples = [(x, y, z) for x in entities_arg1
                   for y in entities_arg2 for z in entities_arg3
                   if x != y and y != z and z != x]

        num_pair_samples = min(len(pairs), int(len(entities_arg1) *
                                               len(entities_arg2) *
                                               arg_densities[1]))
        num_triple_samples = min(len(triples), int(len(entities_arg1) *
                                                   len(entities_arg2) *
                                                   len(entities_arg3) *
                                                   arg_densities[2]))
        # NOTE(review): unary candidates are bare entity strings, while binary
        # and ternary candidates are tuples. Presumably KB.add_train accepts
        # both; confirm against KB.
        entities_per_arity = {
            1: entities_arg1,
            2: random.sample(pairs, num_pair_samples),
            3: random.sample(triples, num_triple_samples)
        }

        relations_per_arity = {}
        for arity in range(1, len(num_relations_per_arity) + 1):
            for i in range(1, num_relations_per_arity[arity - 1] + 1):
                # The uniform draw happens for every relation, so the random
                # stream does not depend on the fb_densities values.
                fb_prefix = ""
                if fb_densities[arity - 1] > random.uniform(0, 1.0):
                    fb_prefix = "REL$"
                if arity == 1:
                    rel = fb_prefix + "u"
                elif arity == 2:
                    rel = fb_prefix + "b"
                else:
                    rel = fb_prefix + "t"
                rel += str(i)

                relations_per_arity.setdefault(arity, []).append(rel)

                candidates = entities_per_arity[arity]
                for args in random.sample(candidates,
                                          int(len(candidates) * fact_prob)):
                    self.kb.add_train(rel, args)

        inverse = []
        if 2 in relations_per_arity:
            binary = relations_per_arity[2]
            # Symmetric relations r(X,Y) => r(Y,X), stored as the pair (r, r).
            inverse += self._sample_up_to([(r, r) for r in binary], num_symm)
            # Reversed implications r1(X,Y) => r2(Y,X).
            inverse += self._sample_up_to([(x, y) for x in binary
                                           for y in binary if x != y],
                                          num_impl_inv)
        if inverse:
            self.kb.add_formulae("inv", {2: inverse})

        # Implications r1(X) => r2(X), r1(X,Y) => r2(X,Y), ...
        implications_per_arity = {}
        for arity in range(1, len(num_relations_per_arity) + 1):
            if arity in relations_per_arity:
                rels = relations_per_arity[arity]
                implications_per_arity[arity] = self._sample_up_to(
                    [(x, y) for x in rels for y in rels if x != y],
                    num_impl[arity - 1])
        self.kb.add_formulae("impl", implications_per_arity)

        # Implications with a conjunction in the body:
        # r1(X,Y) ^ r2(X,Y) => r3(X,Y), r1(X) ^ r2(X) => r3(X)
        implications_with_conjunction_per_arity = {}
        for arity in range(1, len(num_relations_per_arity) + 1):
            if arity in relations_per_arity and len(relations_per_arity[arity]) >= 3:
                rels = relations_per_arity[arity]
                implications_with_conjunction_per_arity[arity] = self._sample_up_to(
                    [(x, y, z) for x in rels for y in rels for z in rels
                     if x != y and y != z and z != x],
                    num_impl_conj[arity - 1])
        self.kb.add_formulae("impl_conj", implications_with_conjunction_per_arity)

        # Transitivities:
        # (1) simple   r(X,Y) ^ r(Y,Z) => r(X,Z)
        # (2) general  r1(X,Y) ^ r2(Y,Z) => r3(X,Z)  (r1, r2, r3 differ)
        transitivities = []
        if 2 in relations_per_arity:
            binary = relations_per_arity[2]
            if num_trans_single > 0:
                transitivities += self._sample_up_to(
                    [(r, r, r) for r in binary], num_trans_single)
            if num_trans_diff > 0:
                transitivities += self._sample_up_to(
                    [(x, y, z) for x in binary for y in binary for z in binary
                     if x != y and y != z and z != x],
                    num_trans_diff)
        if transitivities:
            self.kb.add_formulae("trans", {2: transitivities})

        # TODO: sampling negation (also applies to all heads of formulae above):
        # r1 => !r2

    @staticmethod
    def _sample_up_to(population, k):
        """Return ``random.sample(population, k)`` with ``k`` capped at ``len(population)``.

        When ``k`` fits, this is exactly ``random.sample(population, k)``, so
        the global random stream is consumed as before.
        """
        return random.sample(population, min(k, len(population)))

    def get_kb(self):
        """Return the generated ``KB`` instance."""
        return self.kb