def get_e_pos_neg(self, edge_pred_dict, training=True):
        """
        Align e_pos and e_neg to shape (num_edge, ) and (num_edge, num_nodes_neg). Ignores reverse edges
        :param edge_pred_dict:
        :return:
        """
        e_pos = torch.cat([e_pred for metapath, e_pred in edge_pred_dict.items() \
                           if not is_negative(metapath) and metapath in self.dataset.metapaths], dim=0)
        e_neg = torch.cat([e_pred for metapath, e_pred in edge_pred_dict.items() if
                           is_negative(metapath) and untag_negative(metapath) in self.dataset.metapaths], dim=0)

        if training:
            num_nodes_neg = int(self.hparams.neg_sampling_ratio)
        else:
            num_nodes_neg = int(e_neg.numel() // e_pos.numel())

        if e_neg.size(0) % num_nodes_neg:
            e_neg = e_neg[:e_neg.size(0) - e_neg.size(0) % num_nodes_neg]
        e_neg = e_neg.view(-1, num_nodes_neg)

        # ensure same num_edge in dim 0
        min_idx = min(e_pos.size(0), e_neg.size(0))
        e_pos = e_pos[:min_idx]
        e_neg = e_neg[:min_idx]

        return e_pos, e_neg
    def add_reverse_edge_index(edge_index_dict) -> None:
        reverse_edge_index_dict = {}
        for metapath in edge_index_dict:
            if is_negative(metapath) or edge_index_dict[metapath] == None: continue
            reverse_metapath = HeteroNetDataset.get_reverse_metapath_name(metapath, edge_index_dict)

            reverse_edge_index_dict[reverse_metapath] = edge_index_dict[metapath][[1, 0], :]
        edge_index_dict.update(reverse_edge_index_dict)
Beispiel #3
0
    def process_edge_reltype_dataset(self, dataset: PygLinkPropPredDataset):
        data = dataset[0]
        self._name = dataset.name
        self.edge_reltype = data.edge_reltype

        if hasattr(data, "num_nodes_dict"):
            self.num_nodes_dict = data.num_nodes_dict
        elif not hasattr(data, "edge_index_dict"):
            self.head_node_type = "entity"
            self.num_nodes_dict = {
                self.head_node_type: data.edge_index.max().item() + 1
            }

        if self.node_types is None:
            self.node_types = list(self.num_nodes_dict.keys())

        if hasattr(data, "x") and data.x is not None:
            self.x_dict = {self.head_node_type: data.x}
        elif hasattr(data, "x_dict") and data.x_dict is not None:
            self.x_dict = data.x_dict
        else:
            self.x_dict = {}

        self.metapaths = [(self.head_node_type, str(k.item()),
                           self.head_node_type)
                          for k in self.edge_reltype.unique()]
        self.edge_index_dict = {k: None for k in self.metapaths}

        split_idx = dataset.get_edge_split()
        train_triples, valid_triples, test_triples = split_idx[
            "train"], split_idx["valid"], split_idx["test"]
        self.triples = {}
        for key in train_triples.keys():
            if isinstance(train_triples[key], torch.Tensor):
                self.triples[key] = torch.cat([
                    valid_triples[key], test_triples[key], train_triples[key]
                ],
                                              dim=0)
            else:
                self.triples[key] = np.array(valid_triples[key] +
                                             test_triples[key] +
                                             train_triples[key])

        for key in valid_triples.keys():
            if is_negative(key):  # either head_neg or tail_neg
                self.triples[key] = torch.cat(
                    [valid_triples[key], test_triples[key]], dim=0)

        self.start_idx = {
            "valid": 0,
            "test": len(valid_triples["relation"]),
            "train":
            len(valid_triples["relation"]) + len(test_triples["relation"])
        }

        self.validation_idx = torch.arange(
            self.start_idx["valid"],
            self.start_idx["valid"] + len(valid_triples["relation"]))
        self.testing_idx = torch.arange(
            self.start_idx["test"],
            self.start_idx["test"] + len(test_triples["relation"]))
        self.training_idx = torch.arange(
            self.start_idx["train"],
            self.start_idx["train"] + len(train_triples["relation"]))

        assert self.validation_idx.max() < self.testing_idx.min()
        assert self.testing_idx.max() < self.training_idx.min()
Beispiel #4
0
    def sample(self, iloc):
        if not isinstance(iloc, torch.Tensor):
            iloc = torch.tensor(iloc)

        # Add neg edges if valid or test
        if iloc.max() < self.start_idx["train"]:
            has_neg_edges = True
        else:
            has_neg_edges = False

        X = {"edge_index_dict": {}, "global_node_index": {}, "x_dict": {}}

        triples = {
            k: v[iloc]
            for k, v in self.triples.items() if not is_negative(k)
        }
        if has_neg_edges:
            triples.update({
                k: v[iloc]
                for k, v in self.triples.items() if is_negative(k)
            })

        # Gather all nodes sampled
        relation_ids_all = triples["relation"].unique()
        for relation_id in relation_ids_all:
            metapath = self.metapaths[relation_id]
            head_type, tail_type = metapath[0], metapath[-1]

            mask = triples["relation"] == relation_id
            X["global_node_index"].setdefault(head_type,
                                              []).append(triples["head"][mask])
            X["global_node_index"].setdefault(tail_type,
                                              []).append(triples["tail"][mask])
            if has_neg_edges:
                X["global_node_index"].setdefault(head_type, []).append(
                    triples["head_neg"][mask].view(-1))
                X["global_node_index"].setdefault(tail_type, []).append(
                    triples["tail_neg"][mask].view(-1))

        X["global_node_index"] = {node_type: torch.cat(node_sets, dim=0).unique() \
                                  for node_type, node_sets in X["global_node_index"].items()}

        local2batch = {
            node_type: dict(
                zip(X["global_node_index"][node_type].numpy(),
                    range(len(X["global_node_index"][node_type]))))
            for node_type in X["global_node_index"]
        }

        # Get edge_index with batch id
        for relation_id in relation_ids_all:
            metapath = self.metapaths[relation_id]
            head_type, tail_type = metapath[0], metapath[-1]

            mask = triples["relation"] == relation_id
            sources = triples["head"][mask].apply_(local2batch[head_type].get)
            targets = triples["tail"][mask].apply_(local2batch[tail_type].get)
            X["edge_index_dict"][metapath] = torch.stack([sources, targets],
                                                         dim=1).t()

            if has_neg_edges:
                head_neg = triples["head_neg"][mask].apply_(
                    local2batch[head_type].get)
                tail_neg = triples["tail_neg"][mask].apply_(
                    local2batch[tail_type].get)
                head_batch = torch.stack(
                    [head_neg.view(-1),
                     targets.repeat(head_neg.size(1))])
                tail_batch = torch.stack(
                    [sources.repeat(tail_neg.size(1)),
                     tail_neg.view(-1)])
                X["edge_index_dict"][tag_negative(metapath)] = torch.cat(
                    [head_batch, tail_batch], dim=1)

        if self.use_reverse:
            self.add_reverse_edge_index(X["edge_index_dict"])

        # Make x_dict
        if hasattr(self, "x_dict") and len(self.x_dict) > 0:
            X["x_dict"] = {node_type: self.x_dict[node_type][X["global_node_index"][node_type]] \
                           for node_type in self.x_dict}

        return X, None, None