Exemplo n.º 1
0
    def forward(self,
                q: QTensor,
                edge_index: Adj,
                edge_attr: QTensor,
                size: Size = None) -> QTensor:
        """Run quaternion message passing and combine the result with the input features.

        Both QTensors are flattened into real tensors so that ``self.propagate`` can
        operate on them. The aggregated output is then split back into a QTensor.

        The order of the remaining steps depends on ``self.same_dim``:
          * True:  aggregate -> ``self.transform`` -> add input (if ``self.add_self_loops``)
          * False: aggregate -> add input (if ``self.add_self_loops``) -> ``self.transform``

        :param q: quaternion node features.
        :param edge_index: graph connectivity, forwarded to ``self.propagate``.
        :param edge_attr: quaternion edge features; its class must be named ``QTensor``.
        :param size: optional size, forwarded to ``self.propagate``.
        :return: updated quaternion node features.
        """
        assert edge_attr.__class__.__name__ == "QTensor"
        residual = q.clone()

        # QTensor -> real tensor: stack components on axis 1, then merge them with the features
        node_flat = q.stack(dim=1)  # (batch_num_nodes, 4, feature_dim)
        node_flat = node_flat.reshape(node_flat.size(0), -1)  # (batch_num_nodes, 4*feature_dim)
        edge_flat = edge_attr.stack(dim=1)
        edge_flat = edge_flat.reshape(edge_flat.size(0), -1)

        aggregated = self.propagate(edge_index=edge_index,
                                    x=node_flat,
                                    edge_attr=edge_flat,
                                    size=size)
        # split the 4*feature_dim axis back into components and move them to the front
        components = aggregated.reshape(aggregated.size(0), 4, -1).permute(1, 0, 2)
        out = QTensor(*components)

        if not self.same_dim:
            # add self-loops first, then transform
            if self.add_self_loops:
                out += residual
            out = self.transform(out)
        else:
            # transform first, then add self-loops
            out = self.transform(out)
            if self.add_self_loops:
                out += residual
        return out
Exemplo n.º 2
0
    def forward(self,
                q: QTensor,
                edge_index: Adj,
                edge_attr: QTensor,
                size: Size = None) -> QTensor:
        """Aggregate neighbour messages, optionally add the input features, then transform.

        ``q`` and ``edge_attr`` are flattened into real tensors of shape
        (rows, 4 * feature_dim) for ``self.propagate``. The aggregated output is
        reshaped back into a QTensor. If ``self.add_self_loops`` is set, the input
        features are added to the aggregation before ``self.transform`` is applied.

        :param q: quaternion node features.
        :param edge_index: graph connectivity, forwarded to ``self.propagate``.
        :param edge_attr: quaternion edge features; its class must be named ``QTensor``.
        :param size: optional size, forwarded to ``self.propagate``.
        :return: transformed quaternion node features.
        """
        assert edge_attr.__class__.__name__ == "QTensor"
        # keep a handle on the input; it is never mutated below, so no clone is needed
        x = q
        # "cast" QTensor back to torch.Tensor
        q = q.stack(dim=1)  # (batch_num_nodes, 4, feature_dim)
        q = q.reshape(q.size(0), -1)  # (batch_num_nodes, 4*feature_dim)

        edge_attr = edge_attr.stack(dim=1)
        edge_attr = edge_attr.reshape(edge_attr.size(0), -1)

        # propagate
        agg = self.propagate(edge_index=edge_index,
                             x=q,
                             edge_attr=edge_attr,
                             size=size)
        agg = agg.reshape(agg.size(0), 4, -1).permute(1, 0, 2)

        agg = QTensor(*agg)

        # The aggregated messages must always feed the transform. Previously, with
        # add_self_loops disabled, only the raw input x was transformed and agg was discarded.
        if self.add_self_loops:
            agg += x

        # transform aggregated node embeddings
        return self.transform(agg)
Exemplo n.º 3
0
    def test_quaternion_dropout(self):
        """Check ``quaternion_dropout`` with independent and with shared masks.

        * ``same=False``: every surviving entry equals the input once the dropout
          rescaling is undone.
        * ``same=True``: the zero pattern is identical across all four components,
          and every surviving entry equals the input once the rescaling is undone.
        """
        batch_size = 128
        in_features = 32
        p = 0.3

        # --- independent masks per component ---
        q = QTensor(*torch.randn(4, batch_size, in_features))
        q_dropped = quaternion_dropout(q=q, p=p, training=True, same=False)

        q_tensor = q.stack(dim=0)
        q_dropped_tensor = q_dropped.stack(dim=0)
        # check that "on"-indices are the same when retrieving the data
        ids = (q_dropped_tensor != 0.0)
        q_on = q_tensor[ids]
        q_dropped_on = q_dropped_tensor[ids] * (1 - p)  # undo the 1/(1-p) rescaling
        assert torch.allclose(q_on, q_dropped_on)

        # --- one mask shared by all components ---
        q = QTensor(*torch.randn(4, batch_size, in_features))
        q_dropped = quaternion_dropout(q=q, p=p, training=True, same=True)

        q_tensor = q.stack(dim=0)
        q_dropped_tensor = q_dropped.stack(dim=0) * (1 - p)  # rescaling

        # the zero pattern must be identical for every quaternion component
        masks = [(x != 0.0).to(torch.float32) for x in q_dropped_tensor]
        for a, b in permutations(masks, 2):
            assert torch.allclose(a, b)

        # surviving entries must match the original input values
        ids = (q_dropped_tensor != 0.0)
        assert torch.allclose(q_tensor[ids], q_dropped_tensor[ids])
Exemplo n.º 4
0
    def test_qtensor_scatter_idx(self):
        """Scatter-summing the stacked (rows, 4, p) tensor must equal scattering each component separately."""
        n_rows, n_feats = 1024, 64
        index = torch.randint(low=0, high=256, size=(n_rows, ), dtype=torch.int64)
        q = QTensor(*torch.randn(4, n_rows, n_feats))

        stacked = q.stack(dim=1)
        assert stacked.size() == torch.Size([n_rows, 4, n_feats])

        summed = scatter_sum(src=stacked, index=index, dim=0, dim_size=stacked.size(0))
        assert summed.size() == stacked.size()
        # bring the component axis to the front so the QTensor can be unpacked from it
        q_joint = QTensor(*summed.permute(1, 0, 2))

        per_part = [scatter_sum(part, index, dim=0, dim_size=q.size(0))
                    for part in (q.r, q.i, q.j, q.k)]
        q_split = QTensor(*per_part)

        assert q_joint == q_split
Exemplo n.º 5
0
 def forward(self, x: QTensor, idx: torch.Tensor, dim: int, dim_size: Optional[int] = None) -> QTensor:
     """Softmax-weighted sum aggregation of quaternion features.

     The scores ``self.beta * x`` are normalised with a scatter-softmax over the groups
     given by ``idx``. The input is weighted by these scores and summed per group.
     Each (component, feature) entry gets its own weight.

     :param x: quaternion features to aggregate.
     :param idx: group index, forwarded to ``torch_scatter``.
     :param dim: axis of the stacked (rows, 4, feature_dim) tensor to scatter along.
                 NOTE(review): axis 1 holds the quaternion components, so callers
                 presumably pass dim=0. Confirm.
     :param dim_size: optional number of output groups.
     :return: aggregated quaternion tensor.
     """
     stacked = x.stack(dim=1)  # (num_nodes_batch, 4, feature_dim)
     scores = torch_scatter.composite.scatter_softmax(src=self.beta * stacked, index=idx, dim=dim)
     pooled = torch_scatter.scatter(src=scores * stacked, index=idx, dim=dim,
                                    dim_size=dim_size, reduce="sum")
     # (num_groups, 4, feature_dim) -> (4, num_groups, feature_dim)
     return QTensor(*pooled.permute(1, 0, 2))
Exemplo n.º 6
0
def quaternion_batch_norm(
    qtensor: QTensor,
    running_mean,
    running_var,
    weight=None,
    bias=None,
    training=True,
    momentum=0.1,
    eps=1e-05,
) -> QTensor:
    """Functional quaternion batch normalization.

    The four components are stacked on axis 0 and whitened jointly with
    ``whiten4x4``. If ``weight`` and ``bias`` are given, an affine map is then
    applied. It mixes the components with a 4x4 matrix per feature and adds a bias:

        out_a = sum_c gamma[a, c] * z_c + beta_a,   a, c in {r, i, j, k}

    :param qtensor: input quaternion tensor.
    :param running_mean: running mean, forwarded to ``whiten4x4``. Pass it together
                         with ``running_var`` or leave both as None.
    :param running_var: running covariance, forwarded as ``running_cov``.
    :param weight: optional affine weight, reshaped to (4, 4, 1, num_features).
                   Pass it together with ``bias`` or leave both as None.
    :param bias: optional affine bias, reshaped to (4, 1, num_features).
    :param training: forwarded to ``whiten4x4``.
    :param momentum: forwarded to ``whiten4x4``.
    :param eps: forwarded to ``whiten4x4`` as ``nugget``.
    :return: normalized quaternion tensor.
    """
    # statistics and affine parameters must each be supplied as a pair, or not at all
    assert (running_mean is None) == (running_var is None)
    assert (weight is None) == (bias is None)

    stacked = qtensor.stack(dim=0)
    z = whiten4x4(q=stacked,
                  training=training,
                  running_mean=running_mean,
                  running_cov=running_var,
                  momentum=momentum,
                  nugget=eps)

    if weight is None or bias is None:
        return QTensor(z[0], z[1], z[2], z[3])

    n_feat = stacked.size(-1)
    gamma = weight.reshape(4, 4, 1, n_feat)
    # output row `a` is the gamma-weighted combination of the four whitened components
    rows = [
        z[0] * gamma[a, 0] + z[1] * gamma[a, 1] + z[2] * gamma[a, 2] + z[3] * gamma[a, 3]
        for a in range(4)
    ]
    z = torch.stack(rows, dim=0) + bias.reshape(4, 1, n_feat)
    return QTensor(z[0], z[1], z[2], z[3])
Exemplo n.º 7
0
def qrelu_naive(q: QTensor) -> QTensor:
    r"""
    Naive quaternion ReLU f(z), where z = a + b*i + c*j + d*k.
    a, b, c, d are real scalars; the last three form the vectorial part of the quaternion.

    f(z) returns z iff a + b + c + d > 0. Otherwise it returns 0.
    f is applied independently at every feature position of q, as in the real-valued case.

    :param q: quaternion tensor of shape (b, d), where b is the batch size and d the dimensionality
    :return: activated quaternion tensor
    """
    stacked = q.stack(dim=0)  # components along the leading axis
    component_sum = stacked.sum(dim=0)
    # heaviside gives 1 where the sum is > 0, `values` (here 0) where it is == 0, and 0 otherwise.
    # torch.heaviside requires `values` to share the input's dtype. Building it with new_zeros
    # keeps dtype and device in sync; torch.zeros (always float32) raised for float64/float16.
    gate = torch.heaviside(component_sum, component_sum.new_zeros(()))
    # gate broadcasts over the leading component axis
    return QTensor(*(stacked * gate))
Exemplo n.º 8
0
def quaternion_dropout(q: QTensor, p: float = 0.2, training: bool = True, same: bool = False) -> QTensor:
    r"""
    Apply dropout to a quaternion tensor whose components have size [num_batch_nodes, d].

    :param q: quaternion tensor with real part r and three hypercomplex parts i, j, k
    :param p: dropout rate. Must be within [0.0 ; 1.0]. If p=0.0, this function returns the input unchanged
    :param training: whether dropout is used in training mode.
                     Dropout is applied only when this is True; otherwise the input is returned unchanged
    :param same: if True, a single Bernoulli mask is drawn from the shape of the first stacked
                 component and unsqueezed along the component axis. That one mask is
                 (presumably via broadcasting in ``torch_dropout``) shared by r, i, j and k, so an
                 entry is dropped in all four parts together. If False, ``F.dropout`` is applied
                 to the stacked tensor, so every component entry is dropped independently
    :return: (dropped-out) quaternion q
    :raises AssertionError: if p is not within [0.0 ; 1.0]
    """
    # the docstring must come first; otherwise it is a plain expression and __doc__ is None
    assert 0.0 <= p <= 1.0, f"dropout rate must be in [0.0 ; 1.0]. {p} was inserted!"
    if training and p > 0.0:
        q = q.stack(dim=0)
        if same:
            mask = get_bernoulli_mask(x=q[0], p=p).unsqueeze(dim=0)
            q = torch_dropout(x=q, p=p, mask=mask)
        else:
            q = F.dropout(q, p=p, training=training)
        return QTensor(*q)
    else:
        return q
Exemplo n.º 9
0
    def test_scatter_batch_idx(self):
        """Per-graph sums from ``scatter_sum`` and ``global_add_pool`` must agree.

        This must hold both on the stacked (nodes, 4, p) tensor and component by component.
        """
        n_graphs, n_nodes, feat = 128, 2048, 64
        batch = torch.randint(low=0,
                              high=n_graphs,
                              size=(n_nodes, ),
                              dtype=torch.int64)
        q = QTensor(*torch.randn(4, n_nodes, feat))

        stacked = q.stack(dim=1)
        assert stacked.size() == torch.Size([n_nodes, 4, feat])

        via_scatter = scatter_sum(src=stacked, index=batch, dim=0)
        via_pool = global_add_pool(stacked, batch=batch)
        assert torch.allclose(via_scatter, via_pool)

        # component axis to the front: (4, num_graphs, feat)
        per_component = via_scatter.permute(1, 0, 2)
        q_joint = QTensor(*per_component)

        parts = (q.r, q.i, q.j, q.k)
        r, i, j, k = (scatter_sum(part, batch, dim=0) for part in parts)
        q_scatter = QTensor(r, i, j, k)

        assert q_joint == q_scatter
        for got, expected in zip(per_component, (r, i, j, k)):
            assert torch.allclose(got, expected)

        q_pool = QTensor(*(global_add_pool(part, batch) for part in parts))

        assert q_joint == q_scatter == q_pool
Exemplo n.º 10
0
 def __call__(self, x: QTensor, batch: Batch) -> QTensor:
     """Apply the wrapped pooling callable ``self.module`` to a quaternion tensor.

     The four components are stacked on axis 1, pooled with ``batch``, and the
     component axis is moved back to the front to rebuild a QTensor.
     """
     pooled = self.module(x=x.stack(dim=1), batch=batch)
     # presumably (num_graphs, 4, feat) -> (4, num_graphs, feat)
     return QTensor(*pooled.permute(1, 0, 2))
Exemplo n.º 11
0
 def forward(self, x: QTensor, idx: torch.Tensor, dim: int, dim_size: Optional[int] = None) -> QTensor:
     """Scatter-reduce quaternion features over the groups in ``idx`` using ``self.reduce``.

     :param x: quaternion features, stacked internally to (*, 4, *).
     :param idx: group index, forwarded to ``torch_scatter.scatter``.
     :param dim: axis of the stacked tensor to scatter along.
     :param dim_size: optional number of output groups.
     :return: reduced quaternion tensor.
     """
     reduced = torch_scatter.scatter(
         src=x.stack(dim=1),
         index=idx,
         dim=dim,
         dim_size=dim_size,
         reduce=self.reduce,
     )
     # move the component axis to the front: (4, *, *)
     return QTensor(*reduced.permute(1, 0, 2))