def test_quaternion_batchnorm(self):
    batch_size = 128
    in_channels = 16
    q = QTensor(*torch.rand(4, batch_size, in_channels) * 2.0 + 5)  # each component: (128, 16)
    batch_norm = QuaternionNorm(num_features=in_channels, type="q-batch-norm", affine=False)
    batch_norm = batch_norm.train()
    y = batch_norm(q)  # each component: (128, 16)
    # after q-batch-norm, the four components of each feature should be whitened:
    # zero mean and an approximately identity 4x4 covariance.
    y_stacked = y.stack(dim=0)  # (4, 128, 16)
    perm = y_stacked.permute(2, 0, 1)  # (16, 4, 128)
    # covariances of shape (16, 4, 4)
    cov = torch.matmul(perm, perm.transpose(-1, -2)) / perm.shape[-1]  # (16, 4, 4)
    eye_covs = torch.stack([torch.eye(4) for _ in range(cov.size(0))], dim=0)
    diffs = cov - eye_covs
    a = torch.abs(diffs).sum().item()
    assert 0.0001 < np.round(a, 4) < 0.05
    # check mean
    assert abs(y_stacked.mean(1).norm().item()) < 1e-4
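# For reference, a minimal sketch of the whitening that "q-batch-norm" is expected to
# perform on the stacked components (an assumption about the idea the test verifies,
# not QuaternionNorm's actual implementation):
import torch

def _whiten_sketch(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # x: (4, batch, features)
    x = x - x.mean(dim=1, keepdim=True)                         # zero mean per component
    xc = x.permute(2, 0, 1)                                     # (features, 4, batch)
    cov = xc @ xc.transpose(-1, -2) / xc.shape[-1]              # (features, 4, 4)
    L = torch.linalg.cholesky(cov + eps * torch.eye(4))         # lower-triangular factor
    white = torch.linalg.solve_triangular(L, xc, upper=False)   # L^{-1} @ xc
    return white.permute(1, 2, 0)                               # back to (4, batch, features)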
class QMLP(nn.Module):
    """ Implements a 2-layer Quaternion Multilayer Perceptron. """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 activation: str = "relu", norm: Union[None, str] = None,
                 init: str = "orthogonal", factor: float = 1, **kwargs) -> None:
        super(QMLP, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.activation_str = activation
        self.qlinear1 = QLinear(in_features=in_features, out_features=int(factor * out_features),
                                bias=bias, init=init)
        self.qlinear2 = QLinear(in_features=int(factor * out_features), out_features=out_features,
                                bias=bias, init=init)
        self.activation = get_functional_activation(activation)
        self.norm_type = norm
        self.factor = factor
        self.init_type = init
        if norm in ["naive-batch-norm", "q-batch-norm"]:
            self.norm_flag = True
            self.norm = QuaternionNorm(num_features=int(factor * out_features), type=norm, **kwargs)
        else:
            self.norm_flag = False
        self.reset_parameters()

    def reset_parameters(self):
        self.qlinear1.reset_parameters()
        self.qlinear2.reset_parameters()
        if self.norm_flag:
            self.norm.reset_parameters()

    def forward(self, q: QTensor):
        q = self.qlinear1(q)
        if self.norm_flag:
            q = self.norm(q)
        q = self.activation(q)
        q = self.qlinear2(q)
        return q

    def __repr__(self):
        return '{}(in_features={}, out_features={}, bias={}, ' \
               'activation="{}", norm="{}", ' \
               'init="{}", factor={})'.format(self.__class__.__name__, self.in_features,
                                              self.out_features, self.bias, self.activation_str,
                                              self.norm_type, self.init_type, self.factor)
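# Hedged usage sketch for QMLP as defined above: push a random QTensor through a
# two-layer quaternion MLP with "q-batch-norm" between the layers. Shapes follow
# the tests in this file; the module paths for the imports are not shown in this excerpt.
mlp = QMLP(in_features=16, out_features=8, norm="q-batch-norm", factor=2)
q_in = QTensor(*torch.rand(4, 128, 16))  # four components, batch of 128, 16 features
q_out = mlp(q_in)                        # QTensor with 8 features per component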
def test_naive_batch_norm(self):
    batch_size = 128
    in_channels = 16
    q = QTensor(*torch.rand(4, batch_size, in_channels) * 2.0 + 5)  # each component: (128, 16)
    batch_norm = QuaternionNorm(num_features=in_channels, type="naive-batch-norm")
    batch_norm = batch_norm.train()
    y = batch_norm(q)
    # assert that the r, i, j, k components each have mean 0
    mean_r = torch.mean(y.r, dim=0)  # (16,)
    mean_i = torch.mean(y.i, dim=0)
    mean_j = torch.mean(y.j, dim=0)
    mean_k = torch.mean(y.k, dim=0)
    assert abs(mean_r.norm().item()) < 1e-5
    assert abs(mean_i.norm().item()) < 1e-5
    assert abs(mean_j.norm().item()) < 1e-5
    assert abs(mean_k.norm().item()) < 1e-5
    # assert that the r, i, j, k components each have a (biased) standard deviation of 1
    std_r = torch.std(y.r, dim=0, unbiased=False)  # (16,)
    std_i = torch.std(y.i, dim=0, unbiased=False)
    std_j = torch.std(y.j, dim=0, unbiased=False)
    std_k = torch.std(y.k, dim=0, unbiased=False)
    assert abs(std_r - 1.0).sum().item() < 0.001
    assert abs(std_i - 1.0).sum().item() < 0.001
    assert abs(std_j - 1.0).sum().item() < 0.001
    assert abs(std_k - 1.0).sum().item() < 0.001
    # naive normalization does not decorrelate the components: the 4x4 covariance
    # per feature is generally far from the identity.
    y_stacked = y.stack(dim=0)  # (4, 128, 16)
    perm = y_stacked.permute(2, 0, 1)  # (16, 4, 128)
    cov = torch.matmul(perm, perm.transpose(-1, -2)) / perm.shape[-1]  # (16, 4, 4)
    eye_covs = torch.stack([torch.eye(4) for _ in range(cov.size(0))], dim=0)
    diffs = cov - eye_covs
    a = torch.abs(diffs).sum().item()
    assert np.round(a, 4) > 1.0
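# Sketch (assumption): "naive-batch-norm" standardizes each of the r, i, j, k components
# independently over the batch, analogous to a plain nn.BatchNorm1d applied per component.
# This is what the per-component mean/std checks above verify.
import torch

def _naive_norm_sketch(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # x: (4, batch, features); normalize each component over the batch dimension
    mean = x.mean(dim=1, keepdim=True)
    var = x.var(dim=1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)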
def test_naive_batchnorm(self):
    # Naive batch normalization for quaternions and for PHM (with phm_dim=4) should be identical.
    phm_dim = 4
    num_features = 16
    batch_size = 128
    quat_bn = QuaternionNorm(type="naive-batch-norm", num_features=num_features).train()
    phm_bn = PHMNorm(type="naive-batch-norm", phm_dim=phm_dim, num_features=num_features).train()
    distances = []
    for i in range(500):
        x = torch.randn(phm_dim, batch_size, num_features)
        x_q = QTensor(*x)
        x_p = x.permute(1, 0, 2).reshape(batch_size, -1)
        y_quat = quat_bn(x_q)
        y_phm = phm_bn(x_p)
        y_quat = y_quat.stack(dim=1)
        y_quat = y_quat.reshape(batch_size, -1)
        d = (y_phm - y_quat).norm().item()
        distances.append(d)
    assert sum(distances) == 0.0
    # eval mode uses the running statistics and the learned weight + bias for rescaling and shifting.
    quat_bn = quat_bn.eval()
    phm_bn = phm_bn.eval()
    distances = []
    for i in range(500):
        x = torch.randn(4, batch_size, num_features)
        x_q = QTensor(*x)
        x_p = x.permute(1, 0, 2).reshape(batch_size, -1)
        y_quat = quat_bn(x_q)
        y_phm = phm_bn(x_p)
        y_quat = y_quat.stack(dim=1)
        y_quat = y_quat.reshape(batch_size, -1)
        d = (y_phm - y_quat).norm().item()
        distances.append(d)
    assert sum(distances) / len(distances) < 1e-4
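# Layout correspondence assumed by the test above: a QTensor with components of shape
# (batch, features) maps to a real PHM tensor of shape (batch, 4 * features), with the
# four component blocks concatenated along the feature axis.
x = torch.randn(4, 2, 3)                    # (phm_dim, batch, features)
x_flat = x.permute(1, 0, 2).reshape(2, -1)  # (batch, 4 * features)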
def __init__(self, in_features: int, out_features: int, bias: bool,
             init: str, activation: str, norm: Optional[str]) -> None:
    super(QLayerBlock, self).__init__()
    assert norm in [None, "naive-batch-norm", "q-batch-norm"]
    self.in_features = in_features
    self.out_features = out_features
    self.bias = bias
    self.init = init
    self.activation_str = activation
    self.activation = get_functional_activation(activation)
    self.norm = QuaternionNorm(num_features=out_features, type=norm) if norm is not None else None
    self.affine = QLinear(in_features=in_features, out_features=out_features, bias=bias, init=init)
    self.reset_parameters()
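# The forward pass of QLayerBlock is not shown in this excerpt; a plausible sketch,
# given the attributes set up above (an assumption, not the repository's actual code):
#
#     def forward(self, q: QTensor) -> QTensor:
#         q = self.affine(q)            # quaternion linear layer
#         if self.norm is not None:
#             q = self.norm(q)          # optional quaternion batch norm
#         return self.activation(q)     # split activation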
def __init__(self, in_features: int, hidden_layers: list, out_features: int,
             activation: str, bias: bool, norm: str, init: str,
             dropout: Union[float, list], same_dropout: bool = False,
             real_trafo: str = "linear") -> None:
    super(QuaternionDownstreamNet, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.hidden_layers = hidden_layers
    self.activation_str = activation
    self.activation_func = get_functional_activation(activation)
    self.init = init
    self.bias = bias
    self.dropout = [dropout] * len(hidden_layers) if isinstance(dropout, float) else dropout
    assert len(self.dropout) == len(self.hidden_layers), "dropout list must have the same length " \
                                                         "as the number of hidden layers"
    self.norm_type = norm
    self.same_dropout = same_dropout
    # affine linear layers
    # input -> first hidden layer
    self.affine = [QLinear(in_features, self.hidden_layers[0], bias=bias, init=init)]
    # hidden layers
    self.affine += [
        QLinear(self.hidden_layers[i], self.hidden_layers[i + 1], bias=bias, init=init)
        for i in range(len(self.hidden_layers) - 1)
    ]
    # output layer
    self.affine += [QLinear(self.hidden_layers[-1], self.out_features, init=init, bias=bias)]
    self.affine = nn.ModuleList(self.affine)
    # transform the output quaternionic vector to a real vector with the RealTransformer module
    self.real_trafo_type = real_trafo
    self.real_trafo = RealTransformer(type=self.real_trafo_type,
                                      in_features=self.out_features, bias=True)
    # normalization layers
    self.norm_flag = False
    if "norm" in self.norm_type:
        norm_type = self.norm_type
        self.norm = [QuaternionNorm(num_features=dim, type=norm_type) for dim in self.hidden_layers]
        self.norm = nn.ModuleList(self.norm)
        self.norm_flag = True
    self.reset_parameters()
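# Hedged usage sketch for QuaternionDownstreamNet: construct it with the arguments from
# the signature above and feed it a random QTensor. The forward pass is not part of this
# excerpt, so the output shape noted below is an assumption (the RealTransformer suggests
# a real-valued output of size out_features).
net = QuaternionDownstreamNet(in_features=64, hidden_layers=[64, 32], out_features=1,
                              activation="relu", bias=True, norm="naive-batch-norm",
                              init="orthogonal", dropout=0.1)
q_in = QTensor(*torch.rand(4, 128, 64))
out = net(q_in)  # expected: real tensor of shape (128, 1) (assumption)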
def __init__(self,
             atom_input_dims: Union[int, list] = ATOM_FEAT_DIMS,
             atom_encoded_dim: int = 128,
             bond_input_dims: Union[int, list] = BOND_FEAT_DIMS,
             naive_encoder: bool = False,
             init: str = "orthogonal",
             same_dropout: bool = False,
             mp_layers: list = [128, 196, 256],
             bias: bool = True,
             dropout_mpnn: list = [0.0, 0.0, 0.0],
             norm_mp: Optional[str] = "naive-batch-norm",
             add_self_loops: bool = True,
             msg_aggr: str = "add",
             node_aggr: str = "sum",
             mlp: bool = False,
             pooling: str = "softattention",
             activation: str = "relu",
             real_trafo: str = "linear",
             downstream_layers: list = [256, 128],
             target_dim: int = 1,
             dropout_dn: Union[list, float] = [0.2, 0.1],
             norm_dn: Optional[str] = "naive-batch-norm",
             msg_encoder: str = "identity",
             **kwargs) -> None:
    super(QuaternionSkipConnectConcat, self).__init__()
    assert activation.lower() in ["relu", "lrelu", "elu", "selu", "swish"]
    assert len(dropout_mpnn) == len(mp_layers)
    assert pooling in ["globalsum", "softattention"], f"pooling variable '{pooling}' wrong."
    assert norm_mp in ["None", None, "naive-batch-norm", "q-batch-norm"]

    if msg_aggr == "sum":  # map to "add" for pytorch_geometric's MessagePassing class
        msg_aggr = "add"
    self.msg_encoder_str = msg_encoder

    # save input args as attributes
    self.atom_input_dims = atom_input_dims
    self.bond_input_dims = bond_input_dims
    # one quaternion number consists of four components, so divide the feature dims by 4
    atom_encoded_dim = atom_encoded_dim // 4
    mp_layers = [dim // 4 for dim in mp_layers]
    downstream_layers = [dim // 4 for dim in downstream_layers]
    self.atom_encoded_dim = atom_encoded_dim
    self.naive_encoder = naive_encoder
    self.init = init
    self.same_dropout = same_dropout
    self.mp_layers = mp_layers
    self.bias = bias
    self.dropout_mpnn = dropout_mpnn
    self.norm_mp = norm_mp
    self.add_self_loops = add_self_loops
    self.msg_aggr_type = msg_aggr
    self.node_aggr_type = node_aggr
    self.mlp_mp = mlp
    self.pooling_type = pooling
    self.activation_str = activation
    self.real_trafo_type = real_trafo
    self.downstream_layers = downstream_layers
    self.target_dim = target_dim
    self.dropout_dn = dropout_dn
    self.norm_dn_type = norm_dn

    # define other attributes needed for the module
    self.input_dim = atom_encoded_dim
    self.f_act = get_functional_activation(self.activation_str)

    # quaternion message-passing layers
    self.convs = [None] * len(mp_layers)
    # batch-normalization layers
    self.norms = [None] * len(mp_layers)
    dims = [atom_encoded_dim] + mp_layers

    # atom encoder
    if not naive_encoder:
        self.atomencoder = QuaternionEncoder(out_dim=atom_encoded_dim,
                                             input_dims=atom_input_dims,
                                             combine="sum")
    else:
        self.atomencoder = NaiveQuaternionEncoder(out_dim=atom_encoded_dim,
                                                  input_dims=atom_input_dims,
                                                  combine="sum")

    # bond encoders
    self.bondencoders = []
    if not naive_encoder:
        module = QuaternionEncoder
    else:
        module = NaiveQuaternionEncoder
    for i in range(len(mp_layers)):
        if i == 0:
            out_dim = self.input_dim
        else:
            out_dim = self.mp_layers[i - 1] + self.input_dim
        self.bondencoders.append(module(input_dims=bond_input_dims, out_dim=out_dim, combine="sum"))
    self.bondencoders = nn.ModuleList(self.bondencoders)

    # prepare quaternion message-passing layers and norms, if applicable
    for i in range(len(mp_layers)):
        if i == 0:
            in_dim = self.input_dim
        else:
            in_dim = self.mp_layers[i - 1] + self.input_dim
        out_dim = self.mp_layers[i]
        self.convs[i] = QMessagePassing(in_features=in_dim, out_features=out_dim,
                                        bias=bias, norm=norm_mp, activation=activation,
                                        init=init, aggr=msg_aggr, mlp=mlp, same_dim=False,
                                        add_self_loops=add_self_loops,
                                        msg_encoder=msg_encoder, **kwargs)
        if norm_mp:
            self.norms[i] = QuaternionNorm(num_features=out_dim, type=norm_mp)

    self.convs = nn.ModuleList(self.convs)
    if norm_mp:
        self.norms = nn.ModuleList(self.norms)

    if pooling == "globalsum":
        self.pooling = QuaternionGlobalSumPooling()
    else:
        self.pooling = QuaternionSoftAttentionPooling(
            embed_dim=self.mp_layers[-1] + self.input_dim,
            init=self.init,
            bias=self.bias,
            real_trafo=self.real_trafo_type)

    # downstream network
    self.downstream = QuaternionDownstreamNet(in_features=self.mp_layers[-1] + self.input_dim,
                                              hidden_layers=self.downstream_layers,
                                              out_features=self.target_dim,
                                              activation=self.activation_str,
                                              bias=self.bias,
                                              norm=self.norm_dn_type,
                                              init=self.init,
                                              dropout=self.dropout_dn,
                                              same_dropout=self.same_dropout,
                                              real_trafo=self.real_trafo_type)

    self.reset_parameters()
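# Hedged construction sketch: instantiating the model with its default arguments. The
# forward pass is not part of this excerpt; since the layers build on pytorch_geometric's
# MessagePassing, it presumably expects a PyG-style batch (node features, edge_index,
# edge features, batch vector) — this is an assumption, not shown here.
model = QuaternionSkipConnectConcat(mp_layers=[128, 196, 256], target_dim=1)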