def __init__(self, in_features: int, out_features: int, bias: bool = True, init: str = "orthogonal") -> None:
    """Create a quaternion linear layer with one real weight matrix per quaternion component.

    :param in_features: number of input features per quaternion component
    :param out_features: number of output features per quaternion component
    :param bias: if True, also create one bias vector per component (b_r, b_i, b_j, b_k)
    :param init: name of the initialization scheme; must be one of
        "glorot-normal", "glorot-uniform", "quaternion", "orthogonal".
        It is stored on ``self.init`` (presumably read by ``reset_parameters``).
    """
    super(QLinear, self).__init__()
    assert init in ["glorot-normal", "glorot-uniform", "quaternion", "orthogonal"]
    self.in_features = in_features
    self.out_features = out_features
    self.bias = bias
    self.init = init

    parts = ("r", "i", "j", "k")
    # torch.Tensor(...) allocates uninitialized storage; the actual values are
    # expected to be written by reset_parameters() below.
    for part in parts:
        setattr(self, f"W_{part}",
                nn.Parameter(torch.Tensor(self.out_features, self.in_features), requires_grad=True))
    if self.bias:
        for part in parts:
            setattr(self, f"b_{part}",
                    nn.Parameter(torch.Tensor(self.out_features), requires_grad=True))

    self.reset_parameters()

    # Bundle the per-component parameters into QTensor views for convenient use.
    self.W = QTensor(self.W_r, self.W_i, self.W_j, self.W_k)
    if self.bias:
        self.b = QTensor(self.b_r, self.b_i, self.b_j, self.b_k)
def forward(self, x: QTensor, verbose=False) -> torch.Tensor:
    """Run the stacked quaternion layers and map the final output to a real tensor.

    The layers in ``self.affine`` are applied in order. After every layer except the
    last one, the following steps run:
    ``self.norm[i]`` (only when ``self.norm_flag`` is set), then ``self.activation_func``,
    then, in training mode with ``self.dropout[i] > 0``, ``quaternion_dropout``
    (sharing one mask across components iff ``self.same_dropout``).
    The last layer's output goes through ``self.real_trafo``.

    :param x: quaternion input
    :param verbose: if True, print each step and intermediate sizes (debugging aid)
    :return: result of ``self.real_trafo`` applied to the last layer's output
    """
    # forward pass
    for i in range(len(self.affine)):
        if verbose:
            print(f"iteration {i}")
            print("input:", x.size())
            print("affine", self.affine[i])
        x = self.affine[i](x)
        # print("out affine:", x.size())
        if i < len(self.affine) - 1:  # only for input->hidden and hidden layers, but not output
            if self.norm_flag:
                if verbose:
                    print("normalization")
                    print("activation")
                x = self.norm[i](x)
                x = self.activation_func(x)
            else:
                if verbose:
                    print("activation")
                x = self.activation_func(x)
            # dropout only while training and only for a positive rate
            if self.training and self.dropout[i] > 0.0:  # and i > 0:
                if verbose:
                    print("dropout")
                x = quaternion_dropout(x, p=self.dropout[i], training=self.training, same=self.same_dropout)
        if verbose:
            print("output:", x.size())
    # at the end, transform the quaternion output vector to a real vector
    x = self.real_trafo(x)
    return x
def forward(self, x: QTensor, idx: torch.Tensor, dim: int, dim_size: Optional[int] = None) -> QTensor:
    """Softmax-weighted sum aggregation of quaternion rows grouped by ``idx``.

    Scores ``self.beta * x`` are softmax-normalized within each group of ``idx``
    along ``dim``. The weighted rows are then summed per group.

    :param x: quaternion input; stacked to (num_nodes_batch, 4, feature_dim)
    :param idx: group index for each row along ``dim``
    :param dim: dimension along which to scatter (the final permute assumes 0)
    :param dim_size: optional number of output groups
    :return: aggregated quaternion tensor
    """
    stacked = x.stack(dim=1)  # (num_nodes_batch, 4, feature_dim)
    scores = self.beta * stacked
    alpha = torch_scatter.composite.scatter_softmax(src=scores, index=idx, dim=dim)
    pooled = torch_scatter.scatter(src=alpha * stacked, index=idx, dim=dim,
                                   dim_size=dim_size, reduce="sum")
    # (num_groups, 4, feature_dim) -> (4, num_groups, feature_dim): one slice per component
    return QTensor(*pooled.permute(1, 0, 2))
def test_simple_mul(self):
    """Check the Hamilton product against the dot/cross-product formulation."""
    # real quaternion tensor multiplication (Hamilton product)
    t1 = QTensor(r=torch.tensor([1.0, 2.0, 3.0, -1.0]),
                 i=torch.tensor([2.0, 2.0, 3.0, -2.0]),
                 j=torch.tensor([2.0, 1.0, 0.0, 1.5]),
                 k=torch.tensor([5.0, 4.0, 3.0, 5.0]))
    t2 = QTensor(r=torch.tensor([2.0, 3.0, 4.0, -0.5]),
                 i=torch.tensor([3.0, 1.0, 2.0, 2.0]),
                 j=torch.tensor([1.0, 0.0, 1.0, -3]),
                 k=torch.tensor([4.0, 3.0, 2.0, -4]))
    # t1 and t2 each hold 4 quaternions; the Hamilton product is applied element-wise,
    # analogous to the Hadamard product for real tensors
    t3 = t1 * t2  # 4 quaternions

    # alternative calculation via dot product and cross product of the vector parts
    # https://www.3dgep.com/understanding-quaternions/#Quaternion_Products
    t1_vec = torch.stack([t1.i, t1.j, t1.k], dim=1)  # (4, 3)
    t2_vec = torch.stack([t2.i, t2.j, t2.k], dim=1)  # (4, 3)

    cross_product = torch.cross(t1_vec, t2_vec, dim=1)
    cross_product_manual = torch.stack([
        t1.j * t2.k - t2.j * t1.k,
        t1.k * t2.i - t2.k * t1.i,
        t1.i * t2.j - t2.i * t1.j,
    ], dim=1)  # (4, 3)
    assert torch.allclose(cross_product, cross_product_manual)

    # real part: r1 * r2 - <v1, v2>
    r = t1.r * t2.r - torch.sum(t1_vec * t2_vec, dim=1)  # (4,)
    # vector part: r1 * v2 + r2 * v1 + v1 x v2
    ijk = t1.r.view(-1, 1) * t2_vec + t2.r.view(-1, 1) * t1_vec + cross_product  # (4, 3)
    rijk = torch.cat([r.unsqueeze(dim=1), ijk], dim=1).t()  # (4 components, 4 quaternions)
    assert torch.allclose(t3.stack(dim=0), rijk)

    # the Hamilton product is not commutative
    t3 = t1 * t2
    t3_not = t2 * t1
    assert not t3 == t3_not
def unitary_init(in_features, out_features, low=0, high=1, dtype=torch.float64) -> (Tensor, Tensor, Tensor, Tensor):
    """Sample a pure-imaginary quaternion tensor and normalize it.

    The real part is all zeros. The i, j, k parts are drawn uniformly from
    [low, high) (defaults: 0 and 1) with shape (in_features, out_features) and the
    requested dtype. The quaternion is then passed through ``QTensor.normalize``
    (the original intent: a unitary quaternion).

    :return: the r, i, j, k components of the normalized quaternion
    """
    def _draw_uniform():
        return torch.FloatTensor(in_features, out_features).to(dtype).uniform_(low, high)

    real = torch.FloatTensor(in_features, out_features).to(dtype).zero_()
    imag = [_draw_uniform() for _ in range(3)]  # sampled in the order i, j, k
    unit = QTensor(real, *imag).normalize()
    return unit.r, unit.i, unit.j, unit.k
def test_mul_grad(self):
    """Check autograd gradients of the Hamilton product f(w, x) = w * x.

    Previously this test built ``y = w * x`` and asserted nothing.
    """
    x = QTensor(r=torch.tensor([1.0]), i=torch.tensor([2.0]),
                j=torch.tensor([2.0]), k=torch.tensor([5.0])).requires_grad_()
    w = QTensor(r=torch.tensor([0.1]), i=torch.tensor([0.2]),
                j=torch.tensor([0.3]), k=torch.tensor([0.4])).requires_grad_()
    # f(w,x) = w*x
    y = w * x

    # Reduce all four output components to a scalar loss and backpropagate.
    loss = y.r.sum() + y.i.sum() + y.j.sum() + y.k.sum()
    loss.backward()

    # Hamilton product (same convention as in test_simple_mul):
    #   y.r = w.r*x.r - w.i*x.i - w.j*x.j - w.k*x.k
    #   y.i = w.r*x.i + w.i*x.r + w.j*x.k - w.k*x.j
    #   y.j = w.r*x.j - w.i*x.k + w.j*x.r + w.k*x.i
    #   y.k = w.r*x.k + w.i*x.j - w.j*x.i + w.k*x.r
    # d loss / d w.r = x.r + x.i + x.j + x.k, etc.
    expected_w = [10.0, -4.0, 2.0, -4.0]
    expected_x = [1.0, 0.0, -0.4, -0.2]
    for comp, val in zip((w.r, w.i, w.j, w.k), expected_w):
        assert torch.allclose(comp.grad, torch.tensor([val]), atol=1e-6)
    for comp, val in zip((x.r, x.i, x.j, x.k), expected_x):
        assert torch.allclose(comp.grad, torch.tensor([val]), atol=1e-6)
def quaternion_batch_norm(
        qtensor: QTensor,
        running_mean,
        running_var,
        weight=None,
        bias=None,
        training=True,
        momentum=0.1,
        eps=1e-05,
) -> QTensor:
    """ Functional implementation of quaternion batch normalization

    The four components are stacked to (4, ..., num_features) and passed to
    ``whiten4x4`` (``running_var`` is forwarded as its ``running_cov`` argument and
    ``eps`` as ``nugget``). If ``weight`` and ``bias`` are given, each output
    component becomes a learned linear mix of the four whitened components plus a
    per-component shift:

        x_c_BN = sum_d gamma_cd * x_d + beta_c      for c, d in (r, i, j, k)

    :param weight: tensor reshapeable to (4, 4, 1, num_features)
    :param bias: tensor reshapeable to (4, 1, num_features)
    :return: normalized quaternion tensor
    """
    # running statistics and affine parameters must each be given together or not at all
    assert (running_mean is None) == (running_var is None)
    assert (weight is None) == (bias is None)

    stacked = qtensor.stack(dim=0)
    z = whiten4x4(q=stacked, training=training, running_mean=running_mean,
                  running_cov=running_var, momentum=momentum, nugget=eps)

    if weight is not None and bias is not None:
        num_features = stacked.size(-1)
        gamma = weight.reshape(4, 4, 1, num_features)
        beta = bias.reshape(4, 1, num_features)
        mixed_rows = []
        for out_c in range(4):
            # accumulate left-to-right: z0*g0 + z1*g1 + z2*g2 + z3*g3
            acc = z[0] * gamma[out_c, 0]
            for in_c in range(1, 4):
                acc = acc + z[in_c] * gamma[out_c, in_c]
            mixed_rows.append(acc)
        z = torch.stack(mixed_rows, dim=0) + beta

    return QTensor(z[0], z[1], z[2], z[3])
def forward(self, q: QTensor, edge_index: Adj, edge_attr: QTensor, size: Size = None) -> QTensor:
    """Message passing over the flattened quaternion features, then transform.

    Node and edge quaternions are flattened to (num, 4 * feature_dim) real tensors,
    aggregated with ``self.propagate`` and reshaped back into a QTensor.
    With ``self.same_dim`` the order is aggregate -> transform -> add self-loops.
    Otherwise it is aggregate -> add self-loops -> transform.
    Self-loops (adding the input ``q``) are used only if ``self.add_self_loops``.
    """
    assert edge_attr.__class__.__name__ == "QTensor"
    residual = q.clone()

    node_flat = q.stack(dim=1)  # (batch_num_nodes, 4, feature_dim)
    node_flat = node_flat.reshape(node_flat.size(0), -1)  # (batch_num_nodes, 4*feature_dim)
    edge_flat = edge_attr.stack(dim=1)
    edge_flat = edge_flat.reshape(edge_flat.size(0), -1)

    messages = self.propagate(edge_index=edge_index, x=node_flat, edge_attr=edge_flat, size=size)
    out = QTensor(*messages.reshape(messages.size(0), 4, -1).permute(1, 0, 2))

    # self-loops are added before the transform unless input/output dims match
    if self.add_self_loops and not self.same_dim:
        out += residual
    out = self.transform(out)
    if self.add_self_loops and self.same_dim:
        out += residual
    return out
def forward(self, q: QTensor, edge_index: Adj, edge_attr: QTensor, size: Size = None) -> QTensor:
    """Aggregate neighbour messages, optionally add self-loops, then transform.

    Node and edge quaternions are flattened to (num, 4 * feature_dim) real tensors
    and aggregated with ``self.propagate``. The result is reshaped back into a
    QTensor. If ``self.add_self_loops`` is set, the input ``q`` is added to the
    aggregate. ``self.transform`` is applied to that aggregate.

    Fix: previously, with ``add_self_loops=False``, the transform was applied to a
    copy of the *input* ``q`` and the aggregated messages were discarded.
    """
    assert edge_attr.__class__.__name__ == "QTensor"
    # "cast" QTensor back to torch.Tensor
    x = q.stack(dim=1)  # (batch_num_nodes, 4, feature_dim)
    x = x.reshape(x.size(0), -1)  # (batch_num_nodes, 4*feature_dim)
    e = edge_attr.stack(dim=1)
    e = e.reshape(e.size(0), -1)

    # propagate
    agg = self.propagate(edge_index=edge_index, x=x, edge_attr=e, size=size)
    agg = QTensor(*agg.reshape(agg.size(0), 4, -1).permute(1, 0, 2))

    if self.add_self_loops:
        # out-of-place addition: the caller's ``q`` is left untouched
        agg = agg + q
    # transform aggregated node embeddings
    return self.transform(agg)
def forward(self, x: QTensor, batch: Batch) -> QTensor:
    """Gate each quaternion by a real-valued score, then sum-pool per graph.

    ``self.linear`` -> ``self.real_trafo`` -> ``self.sigmoid`` produce a real gate
    that multiplies all four components (explicit Hadamard product). The gated
    quaternions are pooled with ``self.sum_pooling`` over ``batch``.
    """
    gate = self.sigmoid(self.real_trafo(self.linear(x)))
    gated = QTensor(*(gate * part for part in (x.r, x.i, x.j, x.k)))
    return self.sum_pooling(gated, batch)
def test_quaternion_batchnorm(self):
    """Quaternion (whitening) batch-norm yields ~identity 4x4 covariance and zero mean."""
    batch_size = 128
    in_channels = 16
    # components drawn from [5, 7): clearly not zero-mean / unit-covariance
    q = QTensor(*torch.rand(4, batch_size, in_channels) * 2.0 + 5)  # each component: (128, 16)
    batch_norm = QuaternionNorm(num_features=in_channels, type="q-batch-norm", **{"affine": False})
    batch_norm = batch_norm.train()
    y = batch_norm(q)  # each component: (128, 16)

    # per feature, the 4x4 covariance of (r, i, j, k) should be close to the identity
    y_stacked = y.stack(dim=0)  # (4, 128, 16)
    perm = y_stacked.permute(2, 0, 1)  # (16, 4, 128)
    cov = torch.matmul(perm, perm.transpose(-1, -2)) / perm.shape[-1]  # (16, 4, 4)
    diffs = cov - torch.eye(4).expand_as(cov)
    a = torch.abs(diffs).sum().item()
    assert 0.0001 < np.round(a, 4) < 0.05
    # check mean
    assert abs(y_stacked.mean(1).norm().item()) < 1e-4
def forward(self, q: QTensor, edge_index: Adj, edge_attr: QTensor, size: Size = None) -> QTensor:
    """Propagate each quaternion component separately, optionally add self-loops, transform.

    ``self.propagate`` is called once per component (r, i, j, k) with the matching
    edge-attribute component. If ``self.add_self_loops`` is set, the input ``q`` is
    added to the aggregate. ``self.transform`` is applied to that aggregate.

    Fixes:
      * With ``add_self_loops=False`` the transform was previously applied to the
        raw input ``q``, and the aggregated messages were discarded.
      * ``q += aggregated`` could mutate the caller's tensor; the addition is now
        out of place.
    """
    assert edge_attr.__class__.__name__ == "QTensor"
    # propagate each part, i.e. real part and the three complex parts
    agg_r = self.propagate(edge_index=edge_index, x=q.r, edge_attr=edge_attr.r, size=size)  # [b_num_nodes, in_features]
    agg_i = self.propagate(edge_index=edge_index, x=q.i, edge_attr=edge_attr.i, size=size)  # [b_num_nodes, in_features]
    agg_j = self.propagate(edge_index=edge_index, x=q.j, edge_attr=edge_attr.j, size=size)  # [b_num_nodes, in_features]
    agg_k = self.propagate(edge_index=edge_index, x=q.k, edge_attr=edge_attr.k, size=size)  # [b_num_nodes, in_features]
    aggregated = QTensor(agg_r, agg_i, agg_j, agg_k)

    if self.add_self_loops:
        aggregated = aggregated + q
    # transform aggregated node embeddings
    return self.transform(aggregated)
def test_qtensor_scatter_idx(self):
    """Scattering the stacked tensor equals scattering each component separately."""
    num_rows = 1024
    idx = torch.randint(low=0, high=256, size=(num_rows, ), dtype=torch.int64)
    num_feats = 64
    x = QTensor(*torch.randn(4, num_rows, num_feats))

    stacked = x.stack(dim=1)
    assert stacked.size() == torch.Size([num_rows, 4, num_feats])
    summed = scatter_sum(src=stacked, index=idx, dim=0, dim_size=stacked.size(0))
    assert summed.size() == stacked.size()
    q_aggr = QTensor(*summed.permute(1, 0, 2))

    out_rows = x.size(0)
    q_aggr2 = QTensor(*(scatter_sum(part, idx, dim=0, dim_size=out_rows)
                        for part in (x.r, x.i, x.j, x.k)))
    assert q_aggr == q_aggr2
def test_inverse_mul(self):
    """q * q^-1 is the real unit, and (p*q)^-1 == q^-1 * p^-1."""
    t1 = QTensor(r=torch.tensor([1.0, 2.0, 3.0, -1.0]),
                 i=torch.tensor([2.0, 2.0, 3.0, -2.0]),
                 j=torch.tensor([2.0, 1.0, 0.0, 1.5]),
                 k=torch.tensor([5.0, 4.0, 3.0, 5.0]))
    identity = t1 * t1.inverse()
    assert torch.allclose(identity.r, torch.ones_like(identity.r))
    for imag in (identity.i, identity.j, identity.k):
        assert torch.allclose(imag, torch.zeros_like(imag))

    # (p*q)^-1 = q^-1 * p^-1
    t2 = QTensor(r=torch.tensor([2.0, 3.0, 4.0, -0.5]),
                 i=torch.tensor([3.0, 1.0, 2.0, 2.0]),
                 j=torch.tensor([1.0, 0.0, 1.0, -3]),
                 k=torch.tensor([4.0, 3.0, 2.0, -4]))
    lhs = (t1 * t2).inverse()
    rhs = t2.inverse() * t1.inverse()
    for a, b in zip((lhs.r, lhs.i, lhs.j, lhs.k), (rhs.r, rhs.i, rhs.j, rhs.k)):
        assert torch.allclose(a, b)
def qrelu_naive(q: QTensor) -> QTensor:
    r"""
    quaternion relu activation function f(z), where z = a + b*i + c*j + d*k
    a,b,c,d are real scalars where the last three scalars correspond to the vectorial part of the quaternion number

    f(z) returns z, iff a + b + c + d > 0.
    Otherwise it returns 0 (also when the sum is exactly 0).
    Note that f(z) is applied for each dimensionality of q, as in real-valued fashion.

    :param q: quaternion tensor of shape (b, d) where b is the batch-size and d the dimensionality
    :return: activated quaternion tensor
    """
    stacked = q.stack(dim=0)  # (4, *q_shape)
    component_sum = stacked.sum(dim=0)
    # zeros_like keeps dtype and device of the input: torch.heaviside rejects a
    # ``values`` tensor whose dtype differs from its input (e.g. float64 quaternions)
    gate = torch.heaviside(component_sum, values=torch.zeros_like(component_sum))
    # the gate broadcasts over the leading component axis
    return QTensor(*(stacked * gate))
def test_naive_batchnorm(self):
    """Naive batch-norm for quaternions and PHM (phm_dim=4) must agree in train and eval mode."""
    phm_dim = 4
    num_features = 16
    batch_size = 128
    quat_bn = QuaternionNorm(type="naive-batch-norm", num_features=num_features).train()
    phm_bn = PHMNorm(type="naive-batch-norm", phm_dim=phm_dim, num_features=num_features).train()

    def _distances(n_trials=500):
        # Feed the same random input to both layers and record the output distance.
        distances = []
        for _ in range(n_trials):
            x = torch.randn(phm_dim, batch_size, num_features)
            y_quat = quat_bn(QTensor(*x))
            y_phm = phm_bn(x.permute(1, 0, 2).reshape(batch_size, -1))
            y_quat = y_quat.stack(dim=1).reshape(batch_size, -1)
            distances.append((y_phm - y_quat).norm().item())
        return distances

    # training mode: batch statistics are used, so both must match exactly
    distances = _distances()
    assert sum(distances) == 0.0

    # eval mode uses running-mean and estimated weights + bias for rescaling and shifting.
    quat_bn = quat_bn.eval()
    phm_bn = phm_bn.eval()
    distances = _distances()
    assert sum(distances) / len(distances) < 1e-4
def forward(self, x: QTensor) -> torch.Tensor:
    """Computes the forward pass to convert a quaternion vector to a real vector

    Depending on ``self.type``:
      * "sum":  element-wise sum of the four components
      * "mean": element-wise average of the four components
      * "norm": delegates to ``x.norm()``
      * otherwise: concatenate r, i, j, k along the last dim and apply ``self.affine``

    Fix: "mean" previously called ``.mean()`` without a dim, collapsing the whole
    batch to a single scalar instead of returning a per-element real vector like
    the other modes.
    """
    if self.type == "sum":
        return x.r + x.i + x.j + x.k
    elif self.type == "mean":
        return (x.r + x.i + x.j + x.k) / 4.0
    elif self.type == "norm":
        return x.norm()
    else:
        x = torch.cat([x.r, x.i, x.j, x.k], dim=-1)
        return self.affine(x)
def quaternion_dropout(q: QTensor, p: float = 0.2, training: bool = True, same: bool = False) -> QTensor:
    r"""
    Applies dropout to a quaternion tensor with real part r and hypercomplex parts i, j, k.

    * ``same=True``: a single mask is built from the real-part tensor (via
      ``get_bernoulli_mask``) and broadcast over the four components. A feature is
      then dropped in all components at once.
    * ``same=False``: ``F.dropout`` is applied to the stacked (4, ...) tensor, so each
      component entry is dropped independently (kept entries are scaled by 1/(1-p)).

    :param q: quaternion tensor, e.g. of size [num_batch_nodes, d] per component
    :param p: dropout rate. Must be within [0.0 ; 1.0]. If p=0.0, this function returns the input
    :param training: boolean flag if the dropout is used in training mode.
                     Only if this is True, the dropout will be applied. Otherwise the input is returned
    :param same: share one dropout mask across the four components (see above)
    :return: (dropped-out) quaternion q
    """
    # NOTE: previously this assert preceded the docstring, which left __doc__ empty.
    assert 0.0 <= p <= 1.0, f"dropout rate must be in [0.0 ; 1.0]. {p} was inserted!"
    if not (training and p > 0.0):
        return q
    stacked = q.stack(dim=0)
    if same:
        mask = get_bernoulli_mask(x=stacked[0], p=p).unsqueeze(dim=0)
        stacked = torch_dropout(x=stacked, p=p, mask=mask)
    else:
        stacked = F.dropout(stacked, p=p, training=training)
    return QTensor(*stacked)
def test_simple_add_grad(self):
    """Gradients of quaternion addition are 1 for every component of both operands."""
    # real quaternion tensor addition
    t1 = QTensor(r=torch.tensor([1.0, 2.0, 3.0]), i=torch.tensor([2.0, 2.0, 3.0]),
                 j=torch.tensor([2.0, 1.0, 0.0]), k=torch.tensor([5.0, 4.0, 3.0])).requires_grad_()
    t2 = QTensor(r=torch.tensor([2.0, 3.0, 4.0]), i=torch.tensor([3.0, 1.0, 2.0]),
                 j=torch.tensor([1.0, 0.0, 1.0]), k=torch.tensor([4.0, 3.0, 2.0])).requires_grad_()
    t3 = t1 + t2
    t3.backward(torch.tensor([1.0, 1.0, 1.0]))

    # d(a + b)/da = d(a + b)/db = 1; with an upstream gradient of ones,
    # every component gradient of both operands is a vector of ones.
    ones = torch.tensor([1.0, 1.0, 1.0])
    for operand in (t1, t2):
        for comp in (operand.r, operand.i, operand.j, operand.k):
            assert torch.allclose(comp.grad, ones)
def test_quaternion_dropout(self):
    """Kept entries are rescaled inputs; with same=True all components share one mask."""
    batch_size, in_features, p = 128, 32, 0.3

    # independent masks per component
    q = QTensor(*torch.randn(4, batch_size, in_features))
    dropped = quaternion_dropout(q=q, p=p, training=True, same=False).stack(dim=0)
    original = q.stack(dim=0)
    kept = dropped != 0.0
    # undo the 1/(1-p) rescaling on the kept entries
    assert torch.allclose(original[kept], dropped[kept] * (1 - p))

    # one mask shared by all components
    q = QTensor(*torch.randn(4, batch_size, in_features))
    dropped = quaternion_dropout(q=q, p=p, training=True, same=True).stack(dim=0)
    dropped = dropped * (1 - p)
    masks = [(comp != 0.0).to(torch.float32) for comp in dropped]
    for a, b in permutations(masks, 2):
        assert torch.allclose(a, b)
def forward(self, q: QTensor, edge_index: Adj, edge_attr: QTensor, size: Size = None) -> QTensor:
    """Propagate each quaternion component separately, then transform.

    ``self.propagate`` runs once per component (r, i, j, k). With ``self.same_dim``
    the order is aggregate -> transform -> add self-loops. Otherwise it is
    aggregate -> add self-loops -> transform. Self-loops (adding the input ``q``)
    are used only if ``self.add_self_loops``.
    """
    assert edge_attr.__class__.__name__ == "QTensor"
    residual = q.clone()

    per_component = [
        self.propagate(edge_index=edge_index, x=getattr(q, part),
                       edge_attr=getattr(edge_attr, part), size=size)  # [b_num_nodes, in_features]
        for part in ("r", "i", "j", "k")
    ]
    out = QTensor(*per_component)

    # self-loops are added before the transform unless input/output dims match
    if self.add_self_loops and not self.same_dim:
        out += residual
    out = self.transform(out)
    if self.add_self_loops and self.same_dim:
        out += residual
    return out
def qrelu_naive2(q: QTensor) -> QTensor:
    r"""
    quaternion relu activation function f(z), where z = a + b*i + c*j + d*k
    a,b,c,d are real scalars where the last three scalars correspond to the vectorial part of the quaternion number

    f(z) returns z iff a, b, c and d are all > 0. Otherwise it returns 0.
    Note that f(z) is applied for each dimensionality of q, as in real-valued fashion.

    :param q: quaternion tensor of shape (b, d) where b is the batch-size and d the dimensionality
    :return: activated quaternion tensor
    """
    keep = (q.r > 0.0) & (q.i > 0.0) & (q.j > 0.0) & (q.k > 0.0)
    return QTensor(*(keep * part for part in (q.r, q.i, q.j, q.k)))
def test_naive_batch_norm(self):
    """Naive batch-norm standardizes r, i, j, k separately but does not whiten them jointly."""
    batch_size = 128
    in_channels = 16
    # components drawn from [5, 7): clearly not zero-mean / unit-variance
    q = QTensor(*torch.rand(4, batch_size, in_channels) * 2.0 + 5)  # each component: (128, 16)
    batch_norm = QuaternionNorm(num_features=in_channels, type="naive-batch-norm")
    batch_norm = batch_norm.train()
    y = batch_norm(q)
    components = (y.r, y.i, y.j, y.k)

    # r, i, j, k all have mean 0 separately
    for comp in components:
        mean = torch.mean(comp, dim=0)  # (16,)
        assert abs(mean.norm().item()) < 1e-5
    # r, i, j, k all have (biased) standard deviation 1 separately
    for comp in components:
        std = torch.std(comp, dim=0, unbiased=False)  # (16,)
        assert abs(std - 1.0).sum().item() < 0.001

    # the per-feature 4x4 covariance of (r, i, j, k) is expected to differ from the identity
    y_stacked = y.stack(dim=0)  # (4, 128, 16)
    perm = y_stacked.permute(2, 0, 1)  # (16, 4, 128)
    cov = torch.matmul(perm, perm.transpose(-1, -2)) / perm.shape[-1]  # (16, 4, 4)
    diffs = cov - torch.eye(4).expand_as(cov)
    a = torch.abs(diffs).sum().item()
    assert np.round(a, 4) > 1.0
def test_quaternion_hypercomplex_dropout(self):
    """Quaternion and PHM dropout rescale kept entries identically."""
    batch_size, in_features, phm_dim = 128, 32, 4
    p = 0.3
    x = torch.randn(phm_dim, batch_size, in_features)

    q_dropped = quaternion_dropout(QTensor(*x), p=p, training=True, same=False).stack(dim=1)
    flat = x.permute(1, 0, 2).reshape(batch_size, -1)
    phm_dropped = phm_dropout(x=flat, p=p, training=True, same=False, phm_dim=phm_dim)
    phm_dropped = phm_dropped.reshape(batch_size, phm_dim, in_features)

    # compare only positions kept ("on") by both dropout variants
    both_on = (q_dropped != 0.0) & (phm_dropped != 0.0)
    assert torch.allclose(q_dropped[both_on], phm_dropped[both_on])
def test_scatter_batch_idx(self):
    """scatter_sum / global_add_pool agree on stacked and per-component quaternion data."""
    n_graphs, n_nodes = 128, 2048
    idx = torch.randint(low=0, high=n_graphs, size=(n_nodes, ), dtype=torch.int64)
    p = 64
    x = QTensor(*torch.randn(4, n_nodes, p))
    parts = (x.r, x.i, x.j, x.k)

    stacked = x.stack(dim=1)
    assert stacked.size() == torch.Size([n_nodes, 4, p])
    stacked_sum = scatter_sum(src=stacked, index=idx, dim=0)
    assert torch.allclose(stacked_sum, global_add_pool(stacked, batch=idx))

    by_component = stacked_sum.permute(1, 0, 2)
    q_aggr = QTensor(*by_component)
    scattered = [scatter_sum(t, idx, dim=0) for t in parts]
    q_aggr2 = QTensor(*scattered)
    assert q_aggr == q_aggr2
    for from_stack, from_part in zip(by_component, scattered):
        assert torch.allclose(from_stack, from_part)

    q_aggr3 = QTensor(*[global_add_pool(t, idx) for t in parts])
    assert q_aggr == q_aggr2 == q_aggr3
def test_simple_mul_left_right(self):
    """Multiplying by a real tensor commutes and matches manual component-wise scaling."""
    def _make_quat():
        return QTensor(r=torch.tensor([1.0, 2.0, 3.0, -1.0]),
                       i=torch.tensor([2.0, 2.0, 3.0, -2.0]),
                       j=torch.tensor([2.0, 1.0, 0.0, 1.5]),
                       k=torch.tensor([5.0, 4.0, 3.0, 5.0])).requires_grad_()

    quat = _make_quat()
    scalars = torch.tensor([5.0, -1.0, 2.0, -2.0])
    assert quat * scalars == scalars * quat

    quat = _make_quat()
    quat *= 2.0
    doubled = QTensor(r=torch.tensor([2.0, 4.0, 6.0, -2.0]),
                      i=torch.tensor([4.0, 4.0, 6.0, -4.0]),
                      j=torch.tensor([4.0, 2.0, 0.0, 3.0]),
                      k=torch.tensor([10.0, 8.0, 6.0, 10.0]))
    assert quat - doubled == QTensor.zeros(4)

    dropout_mask = (torch.empty(16, 9).uniform_() > 0.1).float()
    q = QTensor(*torch.randn(4, 16, 9))
    left = dropout_mask * q
    right = q * dropout_mask
    assert left == right
    manual = QTensor(*(part * dropout_mask for part in (q.r, q.i, q.j, q.k)))
    assert left == right == manual
def forward(self, x: torch.Tensor) -> QTensor:
    """Encode ``x`` once and use independent copies of the encoding as r, i, j and k."""
    encoded = self.encoder(x)
    return QTensor(*(encoded.clone() for _ in range(4)))
def forward(self, x: torch.Tensor) -> QTensor:
    """Map a real tensor to a quaternion by applying one sub-module per component."""
    branches = (self.r, self.i, self.j, self.k)
    return QTensor(*(branch(x) for branch in branches))
def forward(self, x: QTensor, idx: torch.Tensor, dim: int, dim_size: Optional[int] = None) -> QTensor:
    """Scatter-reduce quaternion rows by ``idx`` using ``self.reduce``.

    The QTensor is stacked to (*, 4, *) and reduced with ``torch_scatter.scatter``.
    It is then permuted back so the component axis comes first. The permute assumes
    a 3-D stacked tensor.
    """
    reduced = torch_scatter.scatter(src=x.stack(dim=1), index=idx, dim=dim,
                                    dim_size=dim_size, reduce=self.reduce)
    return QTensor(*reduced.permute(1, 0, 2))
def __call__(self, x: QTensor, batch: Batch) -> QTensor:
    """Apply the wrapped global pooling ``self.module`` to a stacked quaternion tensor.

    The input is stacked to (*, 4, *), pooled over ``batch``, and permuted so that
    the component axis comes first before rebuilding the QTensor.
    """
    pooled = self.module(x=x.stack(dim=1), batch=batch)
    return QTensor(*pooled.permute(1, 0, 2))