Example no. 1
def plot_test_distances(dataset: str):
    limit = 10
    train, valid, test = load_dataset(dataset)
    # map entities to an id
    emap = IMap()
    for h, _, t in train:
        emap.put(h)
        emap.put(t)
    # build the kg
    kg = lil_matrix((len(emap), len(emap)), dtype=np.uint16)
    for h, _, t in train:
        kg[emap[h], emap[t]] = 1
    kg = kg.tocsr()
    test.sort(key=lambda hrt: hrt[0])
    distances = []
    _h = None
    shortest = None
    for h, _, t in tqdm(test, desc="Distances"):
        # test is sorted by head, so the shortest paths computed for the
        # previous triple are reused whenever the head repeats
        if _h != h:
            shortest = dijkstra(kg,
                                limit=limit,
                                indices=emap[h],
                                return_predecessors=False)
            _h = h
        distances.append(shortest[emap[t]])
    distances = np.array(distances)
    # clip unreachable (infinite) distances into the last histogram bin
    distances[distances > limit] = limit + 1
    plt.hist(distances, bins=range(0, limit + 2))
    plt.axvline(distances.mean(), color="red", linestyle="dashed")
    plt.axvline(np.median(distances), color="black")
    plt.title(f"Distances of test triples in the training graph of {dataset}")
    plt.xlabel("distance")
    plt.ylabel("# of test triples")
    plt.show()
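These examples all rely on an IMap helper that is not shown here. Below is a minimal sketch of the interface they appear to assume: an insertion-ordered name-to-id mapping with a reverse lookup, where id 0 is reserved for padding (an assumption based on how paths are padded with 0 in Example no. 4).

class IMap:
    """Hypothetical sketch: maps names to consecutive integer ids, id 0 reserved for padding."""

    def __init__(self):
        self._ids = {"<pad>": 0}   # name -> id
        self._names = ["<pad>"]    # id -> name

    def put(self, name):
        # assign the next free id the first time a name is seen
        if name not in self._ids:
            self._ids[name] = len(self._names)
            self._names.append(name)

    def __getitem__(self, name):
        return self._ids[name]

    def __len__(self):
        return len(self._names)

    def __contains__(self, name):
        return name in self._ids

    def keys(self):
        return self._ids.keys()

    def values(self):
        return self._ids.values()

    def rget(self, idx):
        # reverse lookup: id -> name
        return self._names[idx]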
Example no. 2
def targets(data, dataset: str, min_dist=2, max_dist=3):
	try:
		with open(f"Structures/bad_ex_{dataset}.json", "r") as f:
			return json.load(f)
	except FileNotFoundError:
		emap = IMap()
		r_t = defaultdict(set)
		h_r = defaultdict(set)
		for h, r, t in data:
			emap.put(h)
			emap.put(t)
			r_t[r].add(emap[t])
			h_r[emap[h]].add(r)
		g = lil_matrix((len(emap), len(emap)))
		for h, r, t in data:
			g[emap[h], emap[t]] = 1
		g = g.tocsr()
		ts = []
		for i in trange(len(emap), desc="Bad examples", ncols=140):
			rel_inds = set()
			for r in h_r[i]:
				rel_inds |= r_t[r]
			dists = dijkstra(
				g, directed=False, unweighted=True, indices=i,
				return_predecessors=False, limit=max_dist
			)
			dists_inds = set(np.asarray((min_dist <= dists) & (dists <= max_dist)).nonzero()[0].tolist())
			ts.append(list(dists_inds ^ rel_inds))
		with open(f"Structures/bad_ex_{dataset}.json", "w") as f:
			json.dump(ts, f)
		return ts
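targets caches its result in Structures/bad_ex_<dataset>.json and, for every entity id i, returns the symmetric difference between the set of entities 2 to 3 hops away from i (undirected) and the set of tails reachable through one of i's relations. A hedged usage sketch with made-up triples and a made-up dataset name (the Structures/ directory must exist for the cache file to be written):

import os

os.makedirs("Structures", exist_ok=True)
toy_triples = [
    ("alice", "knows", "bob"),
    ("bob", "knows", "carol"),
    ("carol", "knows", "dave"),
]
bad = targets(toy_triples, dataset="toy", min_dist=2, max_dist=3)
# one list of candidate entity ids per entity id
print(len(bad), bad[0][:5])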
Example no. 3
def sparse_graph(data):
	emap = IMap()
	for h, r, t in data:
		emap.put(h)
		emap.put(t)
	g = lil_matrix((len(emap), len(emap)))
	for h, r, t in data:
		g[emap[h], emap[t]] = 1
	return g.tocsr()
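sparse_graph keeps one directed, unlabelled edge per (head, tail) pair and returns a CSR adjacency matrix that can be fed directly to scipy's graph routines, as in the two examples above. A small usage sketch with hypothetical triples:

from scipy.sparse.csgraph import dijkstra

triples = [("a", "r1", "b"), ("b", "r2", "c"), ("a", "r1", "c")]
g = sparse_graph(triples)
print(g.shape, g.nnz)

# shortest-path distances from the entity with id 0
dist = dijkstra(g, unweighted=True, indices=0, return_predecessors=False)
print(dist)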
Example no. 4
class Graph3D:
    def __init__(self):
        # a mapping between English names and entity ids
        self.emap = IMap()
        # a mapping between English names and relation ids
        self.rmap = IMap()
        # all known r -> t
        self.r_t = defaultdict(set)
        # the Knowledge Graph
        self.kg = defaultdict(list)

    def __sizeof__(self):
        return (self.emap.__sizeof__() + self.rmap.__sizeof__() +
                self.kg.__sizeof__() + sum(edges.__sizeof__()
                                           for edges in self.kg.values()))

    def __getitem__(self, item):
        if isinstance(item, int):
            return self.kg[item]
        return self.kg[self.emap[item]]

    def __iter__(self):
        return iter(self.emap.keys())

    def __len__(self):
        return len(self.emap)

    def __contains__(self, item):
        return item in self.emap and self.emap[item] in self.kg

    def relations(self, entity):
        return [(entity, r, t) for r, t in self[entity]]

    def add(self, *triplets):
        for h, r, t in triplets:
            self.emap.put(h)
            self.rmap.put(r)
            self.emap.put(t)
            self.r_t[self.rmap[r]].add(self.emap[t])
            self.kg[self.emap[h]].append((self.rmap[r], self.emap[t]))
            self.kg[self.emap[t]].append((-self.rmap[r], self.emap[h]))

    def inspect(self, entity):
        return (("-> " + self.rmap.rget(r), self.emap.rget(t)) if r >= 0 else
                ("<- " + self.rmap.rget(-r), self.emap.rget(t))
                for r, t in self.kg[self.emap[entity]])

    def out(self, entity):
        for r, t in self[entity]:
            if r > 0:
                yield self.rmap.rget(r), self.emap.rget(t)

    def to_relfirst_dict(self):
        res = {}
        for h in self.emap.values():
            res[h] = defaultdict(list)
            for r, t in self[h]:
                res[h][r].append(t)
            res[h] = dict(res[h])
        return res

    def random_walks_r(self, h, r, max_depth, neigh_ratio=0.5):
        """
		Random walks from h that only ends in known t for r
		"""
        assert max_depth >= 2
        # <path, <end_node, count>>
        paths = defaultdict(lambda: defaultdict(float))
        candidates = self.r_t[self.rmap[r]]
        # if the node is unknown, return empty paths
        if h not in self or not candidates:
            return paths
        # dynamically scale the number of walks with the degree of h
        neigh_size = max(int(len(self[h]) * neigh_ratio), 5)
        # for each depth, the number of walks to sample at that depth
        for depth, amount in depth_amount(max_depth, neigh_size):
            # walk a random path
            for i in range(amount):
                path = []
                node = h
                # walk a random path with random depth (up to depth)
                for d in range(depth):
                    _r, neigh = choice(self[node])
                    path.append(_r)
                    node = neigh
                if node in candidates:
                    # pad every path with 0 (for batching purposes)
                    path += [0] * (max_depth - len(path))
                    assert len(path) == max_depth
                    # add 1 for the end node in the path distribution
                    paths[tuple(path)][node] += 1

        # normalize path distributions
        for path in paths:
            total = sum(paths[path].values())
            for node in paths[path]:
                paths[path][node] = paths[path][node] / total
        return paths

    def random_paths(self, h, targets, max_depth, neigh_ratio=0.5):
        assert max_depth >= 2
        # {targets: paths}
        paths = defaultdict(set)
        # if the node is unknown, return empty paths
        if h not in self:
            return paths
        h = self.emap[h]
        # accelerate the search for paths to t
        neighs = {n: (-r, t) for t in targets for r, n in self[t]}
        # dynamically scale the number of walks with the degree of h
        neigh_size = max(int(len(self[h]) * neigh_ratio), 5)
        # for each depth, the number of walks to sample at that depth
        for depth, amount in depth_amount(max_depth, neigh_size):
            # walk a random path
            for i in range(amount):
                path = []
                node = h
                # walk a random path up to depth-1
                for d in range(depth - 1):
                    _r, neigh = choice(self[node])
                    path.append(self.rmap.rget(_r))
                    node = neigh
                if node in neighs:
                    _r, _t = neighs[node]
                    path.append(self.rmap.rget(_r))
                    # pad every path with '' (for batching purposes)
                    path += [''] * (max_depth - len(path))
                    assert len(path) == max_depth
                    # add the path
                    paths[_t].add(tuple(path))
        return defaultdict(list,
                           {t: list(ps)
                            for t, ps in paths.items()})

    def browse(self):
        node = list(self.emap.keys())[1]

        def completer(text, state):
            options = [
                self.emap.rget(_t) for _, _t in self[node]
                if self.emap.rget(_t).startswith(text)
            ]
            return options[state] if state < len(options) else None

        readline.parse_and_bind("tab: complete")
        readline.set_completer(completer)

        print(f"{node} |{len(self[node])}|:")
        for r, t in self.inspect(node):
            print(r, t)

        while True:
            print("Type stop to stop or or <entity> to explore the entity")
            inpt = input("")
            print()
            if inpt.lower() in {"stop", "quit", "exit"}:
                break
            if inpt in self:
                node = inpt
                print(f"{node} |{len(self[node])}|:")
                for r, t in self.inspect(node):
                    print(r, t)
            else:
                print(f"{inpt} is not known in the knowledge graph")
Example no. 5
class TransE(KG):
    """
	A reimplementation of TransE which learns embedding for h and r such that h + r ≈ t
	"""
    module: TransEModule
    optimizer: optim.Optimizer
    emap: IMap
    rmap: IMap
    h2t: dict
    device: torch.device

    def __init__(self, lr=0.005, margin=45, dim=50):
        self.path = "Models/TransE/save.pt"
        self.batch_size = 128
        if torch.cuda.is_available():
            print("Using the GPU")
            self.device = torch.device("cuda")
        else:
            print("Using the CPU")
            self.device = torch.device("cpu")
        # hyperparameters
        self.lr = lr
        self.margin = margin
        self.dim = dim
        self.limit = 7

    def inspect_embeddings(self):
        e_avg = self.module.e_embed.weight.mean(dim=0)
        e_var = (e_avg - self.module.e_embed.weight).norm(dim=1).mean()
        print(
            f"E avg norm {e_avg.norm():.2f}, E var {e_var:.2f}, "
            f"R norm avg {self.module.r_embed.weight.norm(dim=1).mean():.2f}")
        plot_entity(self.module.e_embed.weight.cpu().detach().numpy(),
                    self.emap)

    def epoch(self, it, learn=True):
        roll_loss = deque(maxlen=50 if learn else None)
        roll_pd = deque(maxlen=50 if learn else None)
        roll_nd = deque(maxlen=50 if learn else None)
        for pos_triples, neg_triples in it:
            pos_triples = torch.stack(pos_triples).to(torch.long).to(
                self.device)
            neg_triples = torch.stack(neg_triples).to(torch.long).to(
                self.device)
            self.optimizer.zero_grad()
            # score the positive and negative triples
            pos_dist, neg_dist = self.module(pos_triples, neg_triples)
            loss = self.criterion(pos_dist, neg_dist)
            # track plain floats so the computation graph is not kept alive
            roll_pd.append(pos_dist.mean().item())
            roll_nd.append(neg_dist.mean().item())
            # learn
            if learn:
                loss.backward()
                self.optimizer.step()
            roll_loss.append(loss.item())
            # display loss
            it.set_postfix_str(
                f"{'' if learn else 'val '}loss: {sum(roll_loss)/len(roll_loss):.2f}, "
                f"pos dist: {sum(roll_pd)/len(roll_pd):.2f}, "
                f"neg dist: {sum(roll_nd)/len(roll_nd):.2f}")
        return sum(roll_loss) / len(roll_loss), sum(roll_pd) / len(
            roll_pd), sum(roll_nd) / len(roll_nd)

    def criterion(self, pd, nd):
        return torch.clamp_min(pd - nd + self.margin, 0).mean()

    def train(self, train, valid, dataset: str):
        path = "Models/TransE/save.pt"

        # prepare the data
        self.emap = IMap()
        self.rmap = IMap()
        self.h2t = defaultdict(list)
        for h, r, t in train:
            self.emap.put(h)
            self.emap.put(t)
            self.rmap.put(r)
            self.h2t[h].append(self.emap[t])
        for h, tails in self.h2t.items():
            self.h2t[h] = torch.tensor(tails)
        train_batch = data.DataLoader(TripleData(train, self.emap, self.rmap),
                                      batch_size=self.batch_size)
        valid_batch = data.DataLoader(TripleData(valid, self.emap, self.rmap),
                                      batch_size=self.batch_size)

        # prepare the model
        self.module = TransEModule(len(self.emap),
                                   len(self.rmap),
                                   dim=self.dim).to(self.device)
        self.optimizer = optim.Adam(self.module.parameters(), lr=self.lr)

        # train it
        epoch = 1
        best_val = float("+inf")
        patience = 5
        p = patience
        print(f"Early stopping with patience {patience}")
        while p > 0:
            print(f"Epoch {epoch}")

            # training
            self.module.train()
            train_it = tqdm(train_batch, desc="\tTraining", file=sys.stdout)
            self.epoch(train_it)

            # validation
            self.module.eval()
            valid_it = tqdm(valid_batch, desc="\tValidating", file=sys.stdout)
            with torch.no_grad():
                v_loss, v_pd, v_nd = self.epoch(valid_it, learn=False)
            if v_loss < best_val:
                torch.save(self.module, path)
                best_val = v_loss
                p = patience
            else:
                p -= 1
            epoch += 1
            print()
        print(
            f"Loading best val loss = {best_val:.2f} at epoch {epoch-patience-1}"
        )
        # self.module = torch.load(path)
        self.inspect_embeddings()

    def load(self, train, valid, dataset: str):
        # prepare the data
        self.emap = IMap()
        self.rmap = IMap()
        self.h2t = defaultdict(list)
        for h, r, t in train:
            self.emap.put(h)
            self.emap.put(t)
            self.rmap.put(r)
            self.h2t[h].append(self.emap[t])
        for h, tails in self.h2t.items():
            self.h2t[h] = torch.tensor(tails)
        self.module = torch.load(self.path)
        self.optimizer = optim.Adam(self.module.parameters(), lr=self.lr)
        valid_batch = data.DataLoader(TripleData(valid, self.emap, self.rmap),
                                      batch_size=self.batch_size)
        valid_it = tqdm(valid_batch,
                        ncols=140,
                        desc="\tValidating",
                        file=sys.stdout)
        with torch.no_grad():
            self.epoch(valid_it, learn=False)
        self.inspect_embeddings()

    def link_completion(self, n, couples) -> List[List[str]]:
        preds = []
        idx2e = list(self.emap.keys())
        self.module.eval()
        with torch.no_grad():
            for h, r in couples:
                # get predictions
                hid = torch.tensor([self.emap[h]], device=self.device)
                rid = torch.tensor([self.rmap[r]], device=self.device)
                d = self.module.e_embed(hid) + self.module.r_embed(rid)
                # find the closest embeddings
                distances = torch.norm(self.module.e_embed.weight -
                                       d.view(-1)[None, :],
                                       dim=1)
                # filter out the direct connections to boost accuracy
                distances[self.h2t[h]] = float("+inf")
                vals, indices = distances.topk(k=n, largest=False)
                preds.append([idx2e[i] for i in indices.flatten().tolist()])
        return preds
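The criterion is the standard TransE margin ranking loss, mean(max(0, d_pos - d_neg + margin)). A hedged end-to-end sketch, assuming load_dataset, TripleData and TransEModule from the surrounding repository are importable and using a placeholder dataset name:

# train/valid/test are lists of (h, r, t) string triples
train, valid, test = load_dataset("FB15k-237")  # placeholder dataset name

model = TransE(lr=0.005, margin=45, dim=50)
model.train(train, valid, dataset="FB15k-237")  # early-stops on validation loss

# rank the n closest entity embeddings to h + r for each (h, r) query;
# queries are taken from training triples so heads and relations are known to the maps
queries = [(h, r) for h, r, _ in train[:10]]
predictions = model.link_completion(10, queries)
for (h, r), preds in zip(queries, predictions):
    print(h, r, "->", preds[:3])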