Example 1
    def loss(self, data, split="train"):
        # Pick the edge mask for the requested split.
        if split == "train":
            mask = data.train_mask
        elif split == "val":
            mask = data.val_mask
        else:
            mask = data.test_mask
        edge_index, edge_types = data.edge_index[:, mask], data.edge_attr[mask]

        # Cache the set of observed triples, then sample a batch of positive
        # and negative edges uniformly, with label smoothing on the labels.
        self.get_edge_set(edge_index, edge_types)
        batch_edges, batch_attr, samples, rels, labels = sampling_edge_uniform(
            edge_index,
            edge_types,
            self.edge_set,
            self.sampling_rate,
            self.num_rels,
            label_smoothing=self.lbl_smooth,
            num_entities=self.num_entities)
        node_embed, rel_embed = self.forward(batch_edges, batch_attr)

        # Compact the sampled node ids and remap the sampled edges onto
        # consecutive indices into node_embed.
        sampled_nodes, reindexed_edges = torch.unique(samples,
                                                      sorted=True,
                                                      return_inverse=True)
        # Sanity-check the sampled nodes against the cached node index.
        assert (self.cache_index == sampled_nodes).any()
        # Triple-scoring loss plus a regularization penalty on the node and
        # relation embeddings that took part in this batch.
        loss_n = self._loss(node_embed[reindexed_edges[0]],
                            node_embed[reindexed_edges[1]], rel_embed[rels],
                            labels)
        loss_r = self.penalty * self._regularization(
            [self.emb(sampled_nodes), rel_embed])
        return loss_n + loss_r
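
The compaction step above leans on torch.unique with return_inverse=True: the inverse indices keep the shape of samples, so its two rows can directly index an embedding matrix that covers only the sampled nodes. A minimal standalone sketch of that mechanism (tensor values are illustrative):

    import torch

    # samples holds raw node ids, shape (2, num_sampled_edges); values illustrative.
    samples = torch.tensor([[10, 42, 10],
                            [42, 77, 77]])
    sampled_nodes, reindexed_edges = torch.unique(samples, sorted=True,
                                                  return_inverse=True)
    # sampled_nodes   -> tensor([10, 42, 77])
    # reindexed_edges -> tensor([[0, 1, 0],
    #                            [1, 2, 2]])
    # reindexed_edges[0] / reindexed_edges[1] now index rows of an embedding
    # matrix built only over the sampled nodes.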
Example 2
    def loss(self, data, split="train"):
        # Pick the edge mask for the requested split.
        if split == "train":
            mask = data.train_mask
        elif split == "val":
            mask = data.val_mask
        else:
            mask = data.test_mask
        edge_index, edge_types = data.edge_index[:, mask], data.edge_attr[mask]

        # Cache the set of observed triples and sample a batch of edges uniformly.
        self.get_edge_set(edge_index, edge_types)
        batch_edges, batch_attr, samples, rels, labels = sampling_edge_uniform(
            edge_index, edge_types, self.edge_set, self.sampling_rate,
            self.num_rels)
        output = self.forward(batch_edges, batch_attr)
        # Look up a learned weight vector for each sampled relation type.
        edge_weight = self.rel_weight(rels)
        sampled_nodes, reindexed_edges = torch.unique(samples,
                                                      sorted=True,
                                                      return_inverse=True)
        # Sanity-check the sampled nodes against the cached node index.
        assert (sampled_nodes == self.cache_index).any()
        sampled_types = torch.unique(rels)

        # Scoring loss plus a penalty on the node embeddings and relation
        # weights that appear in this batch.
        loss_n = self._loss(
            output[reindexed_edges[0]], output[reindexed_edges[1]],
            edge_weight, labels) + self.penalty * self._regularization(
                [self.emb(sampled_nodes),
                 self.rel_weight(sampled_types)])
        return loss_n
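
Example 2 scores edges with per-relation weight vectors (rel_weight) rather than the relation embeddings returned by forward in Example 1. A plausible minimal sketch, assuming rel_weight is an embedding table over relation types; the sizes and standalone names below are illustrative assumptions, not taken from the source:

    import torch
    import torch.nn as nn

    num_rels, dim = 5, 8                   # illustrative sizes (assumption)
    rel_weight = nn.Embedding(num_rels, dim)

    rels = torch.tensor([0, 3, 3, 1])      # relation id per sampled edge
    edge_weight = rel_weight(rels)         # (4, dim): one weight vector per edge
    sampled_types = torch.unique(rels)     # tensor([0, 1, 3])
    reg_terms = rel_weight(sampled_types)  # regularize only relations used here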
Example 3
    def loss(self, data: Graph, split="train"):
        # Pick the edge mask for the requested split.
        if split == "train":
            mask = data.train_mask
        elif split == "val":
            mask = data.val_mask
        else:
            mask = data.test_mask
        row, col = data.edge_index
        row, col = row[mask], col[mask]
        edge_types = data.edge_attr[mask]
        edge_index = torch.stack([row, col])

        # Cache the set of observed triples, then sample a batch of positive
        # and negative edges uniformly, with label smoothing on the labels.
        self.get_edge_set(edge_index, edge_types)
        batch_edges, batch_attr, samples, rels, labels = sampling_edge_uniform(
            (row, col),
            edge_types,
            self.edge_set,
            self.sampling_rate,
            self.num_rels,
            label_smoothing=self.lbl_smooth,
            num_entities=self.num_entities,
        )
        # Temporarily swap the sampled batch into the graph; local_graph()
        # restores the original edge_index/edge_attr on exit.
        with data.local_graph():
            data.edge_index = batch_edges
            data.edge_attr = batch_attr
            node_embed, rel_embed = self.forward(data)

        # Compact the sampled node ids and remap the sampled edges onto
        # consecutive indices into node_embed.
        sampled_nodes, reindexed_edges = torch.unique(samples,
                                                      sorted=True,
                                                      return_inverse=True)
        # Sanity-check the sampled nodes against the cached node index.
        assert (self.cache_index == sampled_nodes).any()
        loss_n = self._loss(node_embed[reindexed_edges[0]],
                            node_embed[reindexed_edges[1]], rel_embed[rels],
                            labels)
        loss_r = self.penalty * self._regularization(
            [self.emb(sampled_nodes), rel_embed])
        return loss_n + loss_r
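
Example 3 differs from the first two in that forward consumes the Graph object itself, so the sampled batch is swapped in under data.local_graph(), which restores the graph's topology when the block exits. A minimal sketch of that save-and-restore pattern, assuming it snapshots edge_index and edge_attr (the exact set of saved fields is an assumption for illustration):

    from contextlib import contextmanager

    @contextmanager
    def local_graph(data, fields=("edge_index", "edge_attr")):
        # Hypothetical sketch: snapshot the listed fields, let the caller
        # mutate them, then restore the originals on exit.
        saved = {f: getattr(data, f) for f in fields}
        try:
            yield data
        finally:
            for f, v in saved.items():
                setattr(data, f, v)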