Example no. 1
0
    def forward(self,
                q: QTensor,
                edge_index: Adj,
                edge_attr: QTensor,
                size: Size = None) -> QTensor:
        """Aggregate quaternion node features over the graph and transform them.

        Args:
            q: quaternion node features (one QTensor per graph batch).
            edge_index: graph connectivity in PyG ``Adj`` form.
            edge_attr: quaternion edge features; must be a ``QTensor``.
            size: optional (num_src, num_dst) size hint for ``propagate``.

        Returns:
            Transformed, aggregated quaternion node embeddings.
        """
        assert edge_attr.__class__.__name__ == "QTensor"
        x = q.clone()
        # "cast" QTensor back to torch.Tensor
        q = q.stack(dim=1)  # (batch_num_nodes, 4, feature_dim)
        q = q.reshape(q.size(0), -1)  # (batch_num_nodes, 4*feature_dim)

        edge_attr = edge_attr.stack(dim=1)
        edge_attr = edge_attr.reshape(edge_attr.size(0), -1)

        # propagate messages along the edges
        agg = self.propagate(edge_index=edge_index,
                             x=q,
                             edge_attr=edge_attr,
                             size=size)
        agg = agg.reshape(agg.size(0), 4, -1).permute(1, 0, 2)

        agg = QTensor(*agg)

        # BUG FIX: previously this read `if self.add_self_loops: x += agg` and
        # then transformed `x`, so with add_self_loops=False the aggregated
        # messages were silently discarded and the layer degenerated to a
        # plain transform of its input. The aggregation must always be used;
        # only the self-contribution is conditional (this mirrors the sibling
        # forward implementation in this file).
        if self.add_self_loops:
            agg += x

        # transform aggregated node embeddings
        q = self.transform(agg)
        return q
Example no. 2
0
    def forward(self, x: QTensor, verbose=False) -> torch.Tensor:
        """Run the quaternion MLP layer stack and project to a real tensor.

        Args:
            x: quaternion input features.
            verbose: if True, print a trace of each stage per layer.

        Returns:
            Real-valued output tensor produced by ``self.real_trafo``.
        """
        num_layers = len(self.affine)
        for layer_idx, affine_layer in enumerate(self.affine):
            if verbose:
                print(f"iteration {layer_idx}")
                print("input:", x.size())
                print("affine", affine_layer)
            x = affine_layer(x)

            # norm / activation / dropout apply to every layer except the last
            is_hidden = layer_idx < num_layers - 1
            if is_hidden:
                if self.norm_flag:
                    if verbose:
                        print("normalization")
                        print("activation")
                    x = self.norm[layer_idx](x)
                    x = self.activation_func(x)
                else:
                    if verbose:
                        print("activation")
                    x = self.activation_func(x)
                if self.training and self.dropout[layer_idx] > 0.0:
                    if verbose:
                        print("dropout")
                    x = quaternion_dropout(x,
                                           p=self.dropout[layer_idx],
                                           training=self.training,
                                           same=self.same_dropout)
            if verbose:
                print("output:", x.size())

        # finally, map the quaternion output vector onto a real vector
        return self.real_trafo(x)
Example no. 3
0
    def forward(self,
                q: QTensor,
                edge_index: Adj,
                edge_attr: QTensor,
                size: Size = None) -> QTensor:
        """Aggregate quaternion messages over the graph, then transform.

        Args:
            q: quaternion node features.
            edge_index: graph connectivity in PyG ``Adj`` form.
            edge_attr: quaternion edge features; must be a ``QTensor``.
            size: optional (num_src, num_dst) size hint for ``propagate``.

        Returns:
            Aggregated and linearly transformed quaternion node embeddings.
        """
        assert edge_attr.__class__.__name__ == "QTensor"
        residual = q.clone()

        # flatten the four quaternion components into one real feature axis
        node_feats = q.stack(dim=1)  # (batch_num_nodes, 4, feature_dim)
        node_feats = node_feats.reshape(node_feats.size(0), -1)

        edge_feats = edge_attr.stack(dim=1)
        edge_feats = edge_feats.reshape(edge_feats.size(0), -1)

        # run message passing on the flattened representation
        out = self.propagate(edge_index=edge_index,
                             x=node_feats,
                             edge_attr=edge_feats,
                             size=size)
        out = out.reshape(out.size(0), 4, -1).permute(1, 0, 2)
        out = QTensor(*out)

        if self.same_dim:
            # aggregate messages -> linearly transform -> add self-loops.
            out = self.transform(out)
            if self.add_self_loops:
                out += residual
        else:
            # aggregate messages -> add self-loops -> linearly transform.
            if self.add_self_loops:
                out += residual
            out = self.transform(out)

        return out
Example no. 4
0
    def test_qtensor_scatter_idx(self):
        """scatter_sum on the stacked tensor must match per-component scatter."""
        num_rows, feat_dim = 1024, 64
        idx = torch.randint(0, 256, size=(num_rows, ), dtype=torch.int64)
        q = QTensor(*torch.randn(4, num_rows, feat_dim))

        # stack the 4 components along a new axis: (num_rows, 4, feat_dim)
        stacked = q.stack(dim=1)
        assert stacked.size() == torch.Size([num_rows, 4, feat_dim])

        # route 1: scatter the stacked real tensor, then rebuild the QTensor
        aggregated = scatter_sum(src=stacked,
                                 index=idx,
                                 dim=0,
                                 dim_size=stacked.size(0))
        assert aggregated.size() == stacked.size()
        q_aggr = QTensor(*aggregated.permute(1, 0, 2))

        # route 2: scatter each quaternion component independently
        components = [
            scatter_sum(comp, idx, dim=0, dim_size=q.size(0))
            for comp in (q.r, q.i, q.j, q.k)
        ]
        q_aggr2 = QTensor(*components)

        assert q_aggr == q_aggr2