Exemplo n.º 1
0
    def test_quaternion_batchnorm(self):
        """Quaternion batch-norm (``affine=False``) should whiten every quaternion feature.

        In training mode, each of the ``n_feat`` quaternion features should afterwards
        have a 4x4 second-moment matrix close to (but not exactly) the identity, and
        the per-component batch mean should be close to zero.
        """
        n_batch, n_feat = 128, 16
        # uniform samples in [5, 7): clearly non-zero mean and non-unit variance before normalising
        components = torch.rand(4, n_batch, n_feat) * 2.0 + 5
        q = QTensor(*components)  # (128, 16)

        bn = QuaternionNorm(num_features=n_feat, type="q-batch-norm", affine=False).train()
        out = bn(q)  # (128, 16)

        stacked = out.stack(dim=0)  # (4, 128, 16)
        per_feature = stacked.permute(2, 0, 1)  # (16, 4, 128)
        # empirical second moment over the batch for every feature -> (16, 4, 4);
        # since the mean is ~0 this is the covariance
        cov = per_feature @ per_feature.transpose(-1, -2) / per_feature.shape[-1]
        deviation = torch.abs(cov - torch.eye(4).expand_as(cov)).sum().item()
        assert 0.0001 < np.round(deviation, 4) < 0.05

        # per-component means across the batch should vanish (a norm is never negative)
        assert stacked.mean(1).norm().item() < 1e-4
Exemplo n.º 2
0
class QMLP(nn.Module):
    """Two-layer quaternion multilayer perceptron.

    Layout: ``QLinear -> [QuaternionNorm] -> activation -> QLinear``.
    The hidden width is ``int(factor * out_features)``. The norm layer is only
    created when ``norm`` is ``"naive-batch-norm"`` or ``"q-batch-norm"``; any
    other value (including ``None``) disables it.

    Args:
        in_features: number of input quaternion features.
        out_features: number of output quaternion features.
        bias: whether both QLinear layers use a bias.
        activation: activation name passed to ``get_functional_activation``.
        norm: normalisation type or ``None``.
        init: initialisation scheme passed to both QLinear layers.
        factor: multiplier applied to ``out_features`` to get the hidden width.
        **kwargs: forwarded to ``QuaternionNorm`` when a norm layer is built.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 activation: str = "relu", norm: Union[None, str] = None,
                 init: str = "orthogonal", factor: float = 1, **kwargs) -> None:
        super().__init__()
        hidden = int(factor * out_features)

        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.activation_str = activation
        # submodules are registered in this order (qlinear1, qlinear2, norm),
        # which fixes the parameter / state_dict ordering
        self.qlinear1 = QLinear(in_features=in_features, out_features=hidden, bias=bias, init=init)
        self.qlinear2 = QLinear(in_features=hidden, out_features=out_features, bias=bias, init=init)
        self.activation = get_functional_activation(activation)
        self.norm_type = norm
        self.factor = factor
        self.init_type = init
        self.norm_flag = norm in ("naive-batch-norm", "q-batch-norm")
        if self.norm_flag:
            self.norm = QuaternionNorm(num_features=hidden, type=norm, **kwargs)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear layers and, when present, the norm layer."""
        for layer in (self.qlinear1, self.qlinear2):
            layer.reset_parameters()
        if self.norm_flag:
            self.norm.reset_parameters()

    def forward(self, q: QTensor):
        """Apply ``qlinear1 -> [norm] -> activation -> qlinear2`` to ``q``."""
        hidden = self.qlinear1(q)
        if self.norm_flag:
            hidden = self.norm(hidden)
        return self.qlinear2(self.activation(hidden))

    def __repr__(self):
        name = self.__class__.__name__
        return (f'{name}(in_features={self.in_features}, out_features={self.out_features}, '
                f'bias={self.bias}, activation="{self.activation_str}", norm="{self.norm_type}", '
                f'init="{self.init_type}", factor={self.factor})')
Exemplo n.º 3
0
    def test_naive_batch_norm(self):
        """Naive batch-norm standardises each quaternion component independently.

        Every r/i/j/k component should end up with zero batch mean and unit (biased)
        standard deviation per feature. The components are not decorrelated, so the
        per-feature 4x4 second-moment matrix should stay far from the identity.
        """
        n_batch, n_feat = 128, 16
        q = QTensor(*(torch.rand(4, n_batch, n_feat) * 2.0 + 5))  # (128, 16)

        bn = QuaternionNorm(num_features=n_feat, type="naive-batch-norm").train()
        out = bn(q)
        parts = [out.r, out.i, out.j, out.k]

        # zero mean per component and feature (a norm is never negative)
        for part in parts:
            assert torch.mean(part, dim=0).norm().item() < 1e-5  # mean: (16,)

        # unit biased standard deviation per component and feature
        for part in parts:
            std = torch.std(part, dim=0, unbiased=False)  # (16,)
            assert torch.abs(std - 1.0).sum().item() < 0.001

        # cross-component correlations are untouched, so the matrix deviates clearly from I
        stacked = out.stack(dim=0)  # (4, 128, 16)
        per_feature = stacked.permute(2, 0, 1)  # (16, 4, 128)
        cov = per_feature @ per_feature.transpose(-1, -2) / per_feature.shape[-1]  # (16, 4, 4)
        deviation = torch.abs(cov - torch.eye(4).expand_as(cov)).sum().item()
        assert np.round(deviation, 4) > 1.0

    def test_naive_batchnorm(self):
        """Naive batch-norm must agree between ``QuaternionNorm`` and ``PHMNorm`` with phm_dim=4.

        In training mode the outputs must be exactly identical. In eval mode, which
        rescales with the running statistics and the affine parameters, the mean
        distance must be negligible.
        """
        phm_dim = 4
        num_features = 16
        batch_size = 128
        quat_bn = QuaternionNorm(type="naive-batch-norm", num_features=num_features).train()
        phm_bn = PHMNorm(type="naive-batch-norm", phm_dim=phm_dim, num_features=num_features).train()

        def distance(x):
            # x: (phm_dim, batch, features). The PHM layout is (batch, phm_dim * features),
            # so the quaternion output is stacked on dim 1 and flattened the same way.
            y_quat = quat_bn(QTensor(*x)).stack(dim=1).reshape(batch_size, -1)
            y_phm = phm_bn(x.permute(1, 0, 2).reshape(batch_size, -1))
            return (y_phm - y_quat).norm().item()

        train_dists = [distance(torch.randn(phm_dim, batch_size, num_features)) for _ in range(500)]
        assert sum(train_dists) == 0.0

        # eval mode uses running-mean and estimated weights + bias for rescaling and shifting.
        quat_bn = quat_bn.eval()
        phm_bn = phm_bn.eval()

        eval_dists = [distance(torch.randn(phm_dim, batch_size, num_features)) for _ in range(500)]
        assert sum(eval_dists) / len(eval_dists) < 1e-4
Exemplo n.º 5
0
    def __init__(self, in_features: int, out_features: int, bias: bool, init: str,
                 activation: str, norm: Optional[str]) -> None:
        """Store the block configuration and create its sub-layers.

        Creates an optional ``QuaternionNorm`` over ``out_features`` and an affine
        ``QLinear(in_features -> out_features)``. The activation callable is resolved
        from its name via ``get_functional_activation``.

        Raises:
            AssertionError: if ``norm`` is not ``None``, ``"naive-batch-norm"`` or ``"q-batch-norm"``.
        """
        super(QLayerBlock, self).__init__()
        assert norm in [None, "naive-batch-norm", "q-batch-norm"]
        self.in_features, self.out_features = in_features, out_features
        self.bias, self.init = bias, init
        self.activation_str = activation
        self.activation = get_functional_activation(activation)
        # the norm module (if any) is registered before `affine`, which fixes
        # the parameter / state_dict ordering
        if norm is None:
            self.norm = None
        else:
            self.norm = QuaternionNorm(num_features=out_features, type=norm)
        self.affine = QLinear(in_features=in_features, out_features=out_features, bias=bias, init=init)

        self.reset_parameters()
Exemplo n.º 6
0
    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 activation: str = "relu", norm: Union[None, str] = None,
                 init: str = "orthogonal", factor: float = 1, **kwargs) -> None:
        """Build a two-layer quaternion MLP ``QLinear -> [QuaternionNorm] -> activation -> QLinear``.

        Args:
            in_features: number of input quaternion features.
            out_features: number of output quaternion features.
            bias: whether both QLinear layers use a bias.
            activation: activation name passed to ``get_functional_activation``.
            norm: ``"naive-batch-norm"`` or ``"q-batch-norm"`` to add a norm layer on the
                hidden representation; any other value (including ``None``) disables it.
            init: initialisation scheme passed to both QLinear layers.
            factor: hidden width is ``int(factor * out_features)``.
            **kwargs: forwarded to ``QuaternionNorm`` when a norm layer is built.
        """
        super().__init__()
        hidden = int(factor * out_features)

        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.activation_str = activation
        # submodules are registered in this order (qlinear1, qlinear2, norm),
        # which fixes the parameter / state_dict ordering
        self.qlinear1 = QLinear(in_features=in_features, out_features=hidden, bias=bias, init=init)
        self.qlinear2 = QLinear(in_features=hidden, out_features=out_features, bias=bias, init=init)
        self.activation = get_functional_activation(activation)
        self.norm_type = norm
        self.factor = factor
        self.init_type = init
        self.norm_flag = norm in ("naive-batch-norm", "q-batch-norm")
        if self.norm_flag:
            self.norm = QuaternionNorm(num_features=hidden, type=norm, **kwargs)

        self.reset_parameters()
Exemplo n.º 7
0
    def __init__(self,
                 in_features: int,
                 hidden_layers: list,
                 out_features: int,
                 activation: str,
                 bias: bool,
                 norm: Union[None, str],
                 init: str,
                 dropout: Union[float, list],
                 same_dropout: bool = False,
                 real_trafo: str = "linear") -> None:
        """Build the quaternion downstream network.

        The network has one ``QLinear`` per transition
        ``in_features -> hidden_layers[0] -> ... -> hidden_layers[-1] -> out_features``.
        It is followed by a ``RealTransformer`` that maps the quaternion output to a
        real vector. When ``norm`` contains the substring ``"norm"``, one
        ``QuaternionNorm`` is created per hidden layer.

        Args:
            in_features: number of input quaternion features.
            hidden_layers: hidden widths (quaternion features); must be non-empty.
            out_features: number of output quaternion features.
            activation: activation name passed to ``get_functional_activation``.
            bias: whether the QLinear layers use a bias.
            norm: normalisation type. ``None``, or any string without ``"norm"``
                (e.g. ``"None"``), disables normalisation.
            init: initialisation scheme for the QLinear layers.
            dropout: one rate (int or float) shared by all hidden layers, or a list
                with one rate per hidden layer.
            same_dropout: stored as-is for use elsewhere in the class.
            real_trafo: type passed to ``RealTransformer``.

        Raises:
            AssertionError: if a dropout list does not match ``hidden_layers`` in length.
            IndexError: if ``hidden_layers`` is empty.
        """
        super(QuaternionDownstreamNet, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.hidden_layers = hidden_layers
        self.activation_str = activation
        self.activation_func = get_functional_activation(activation)
        self.init = init
        self.bias = bias
        # a scalar rate (also an int such as 0) is broadcast to every hidden layer
        self.dropout = [dropout] * len(hidden_layers) if isinstance(
            dropout, (int, float)) else dropout
        assert len(self.dropout) == len(self.hidden_layers), "dropout list must be of the same size " \
                                                             "as number of hidden layer"
        self.norm_type = norm
        self.same_dropout = same_dropout

        if not self.hidden_layers:
            raise IndexError("hidden_layers must contain at least one layer size")

        # affine linear layers between consecutive widths: input -> hidden ... -> output
        dims = [in_features, *self.hidden_layers, self.out_features]
        self.affine = nn.ModuleList([
            QLinear(d_in, d_out, bias=bias, init=init)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ])

        # transform the output quaternionic vector to real vector with Real_Transformer module
        self.real_trafo_type = real_trafo
        self.real_trafo = RealTransformer(type=self.real_trafo_type,
                                          in_features=self.out_features,
                                          bias=True)

        # normalizations: guard against None first, since `"norm" in None` raises TypeError
        self.norm_flag = False
        if self.norm_type is not None and "norm" in self.norm_type:
            self.norm = nn.ModuleList([
                QuaternionNorm(num_features=dim, type=self.norm_type)
                for dim in self.hidden_layers
            ])
            self.norm_flag = True

        self.reset_parameters()
Exemplo n.º 8
0
    def __init__(self,
                 atom_input_dims: Union[int, list] = ATOM_FEAT_DIMS,
                 atom_encoded_dim: int = 128,
                 bond_input_dims: Union[int, list] = BOND_FEAT_DIMS,
                 naive_encoder: bool = False,
                 init: str = "orthogonal",
                 same_dropout: bool = False,
                 mp_layers: list = [128, 196, 256],
                 bias: bool = True,
                 dropout_mpnn: list = [0.0, 0.0, 0.0],
                 norm_mp: Optional[str] = "naive-batch-norm",
                 add_self_loops: bool = True,
                 msg_aggr: str = "add",
                 node_aggr: str = "sum",
                 mlp: bool = False,
                 pooling: str = "softattention",
                 activation: str = "relu",
                 real_trafo: str = "linear",
                 downstream_layers: list = [256, 128],
                 target_dim: int = 1,
                 dropout_dn: Union[list, float] = [0.2, 0.1],
                 norm_dn: Optional[str] = "naive-batch-norm",
                 msg_encoder: str = "identity",
                 **kwargs) -> None:
        """Build the quaternion message-passing model with skip-concatenation.

        ``atom_encoded_dim``, ``mp_layers`` and ``downstream_layers`` are given in real
        units and divided by 4, because each quaternion has four real components.
        Message-passing layer 0 takes ``atom_encoded_dim // 4`` quaternion features.
        Layer ``i > 0`` takes ``mp_layers[i - 1] // 4 + atom_encoded_dim // 4``,
        presumably the previous output concatenated with the encoded atoms in
        ``forward`` (not visible here). Each layer gets its own bond encoder with a
        matching output width.

        Args:
            norm_mp: ``"naive-batch-norm"``, ``"q-batch-norm"``, or ``None`` / ``"None"``
                for no normalisation. ``"None"`` is normalised to ``None``.
            **kwargs: forwarded to every ``QMessagePassing`` layer.
            (remaining arguments are stored as attributes and passed to the sub-modules)

        Raises:
            AssertionError: for an unsupported ``activation``, ``pooling`` or ``norm_mp``,
                or when ``dropout_mpnn`` and ``mp_layers`` differ in length.
        """
        super(QuaternionSkipConnectConcat, self).__init__()

        assert activation.lower() in ["relu", "lrelu", "elu", "selu", "swish"]
        assert len(dropout_mpnn) == len(mp_layers)
        assert pooling in ["globalsum", "softattention"
                           ], f"pooling variable '{pooling}' wrong."
        assert norm_mp in ["None", None, "naive-batch-norm", "q-batch-norm"]
        # The string "None" is accepted as "no normalisation" but is truthy, so it
        # would otherwise build QuaternionNorm(type="None") below. Map it to None.
        if norm_mp == "None":
            norm_mp = None

        if msg_aggr == "sum":  # for pytorch_geometrics MessagePassing class.
            msg_aggr = "add"

        self.msg_encoder_str = msg_encoder
        # save input args as attributes
        self.atom_input_dims = atom_input_dims
        self.bond_input_dims = bond_input_dims

        # one quaternion number consists of four components, so divide the feature dims by 4
        atom_encoded_dim = atom_encoded_dim // 4
        mp_layers = [dim // 4 for dim in mp_layers]
        downstream_layers = [dim // 4 for dim in downstream_layers]

        self.atom_encoded_dim = atom_encoded_dim
        self.naive_encoder = naive_encoder
        self.init = init
        self.same_dropout = same_dropout
        self.mp_layers = mp_layers
        self.bias = bias
        # copy list arguments so instances never share (or mutate) the default lists
        self.dropout_mpnn = list(dropout_mpnn)
        self.norm_mp = norm_mp
        self.add_self_loops = add_self_loops
        self.msg_aggr_type = msg_aggr
        self.node_aggr_type = node_aggr
        self.mlp_mp = mlp
        self.pooling_type = pooling
        self.activation_str = activation
        self.real_trafo_type = real_trafo
        self.downstream_layers = downstream_layers
        self.target_dim = target_dim
        self.dropout_dn = list(dropout_dn) if isinstance(dropout_dn, list) else dropout_dn
        self.norm_dn_type = norm_dn

        # define other attributes needed for module
        self.input_dim = atom_encoded_dim
        # NOTE(review): the assert above lower-cases `activation`, but the raw string is used here
        self.f_act = get_functional_activation(self.activation_str)

        # quaternion input width of every MP layer (also the out_dim of its bond encoder)
        in_dims = [self.input_dim if i == 0 else self.mp_layers[i - 1] + self.input_dim
                   for i in range(len(self.mp_layers))]

        encoder_cls = NaiveQuaternionEncoder if naive_encoder else QuaternionEncoder
        # atom-encoder
        self.atomencoder = encoder_cls(out_dim=atom_encoded_dim,
                                       input_dims=atom_input_dims,
                                       combine="sum")
        # bond-encoders, one per MP layer
        self.bondencoders = nn.ModuleList([
            encoder_cls(input_dims=bond_input_dims, out_dim=dim, combine="sum")
            for dim in in_dims
        ])

        # Quaternion MP layers and, if requested, one norm per layer
        # (built interleaved: conv_0, norm_0, conv_1, norm_1, ...)
        convs = []
        norms = [None] * len(mp_layers)
        for i, (in_dim, out_dim) in enumerate(zip(in_dims, self.mp_layers)):
            convs.append(QMessagePassing(in_features=in_dim,
                                         out_features=out_dim,
                                         bias=bias,
                                         norm=norm_mp,
                                         activation=activation,
                                         init=init,
                                         aggr=msg_aggr,
                                         mlp=mlp,
                                         same_dim=False,
                                         add_self_loops=add_self_loops,
                                         msg_encoder=msg_encoder,
                                         **kwargs))
            if norm_mp:
                norms[i] = QuaternionNorm(num_features=out_dim, type=norm_mp)

        self.convs = nn.ModuleList(convs)
        # without normalisation `norms` stays a plain list of None placeholders
        self.norms = nn.ModuleList(norms) if norm_mp else norms

        if pooling == "globalsum":
            self.pooling = QuaternionGlobalSumPooling()
        else:
            self.pooling = QuaternionSoftAttentionPooling(
                embed_dim=self.mp_layers[-1] + self.input_dim,
                init=self.init,
                bias=self.bias,
                real_trafo=self.real_trafo_type)

        # downstream network
        self.downstream = QuaternionDownstreamNet(
            in_features=self.mp_layers[-1] + self.input_dim,
            hidden_layers=self.downstream_layers,
            out_features=self.target_dim,
            activation=self.activation_str,
            bias=self.bias,
            norm=self.norm_dn_type,
            init=self.init,
            dropout=self.dropout_dn,
            same_dropout=self.same_dropout,
            real_trafo=self.real_trafo_type)
        self.reset_parameters()