예제 #1
0
    def forward(self, features_1, features_2=None):
        """
        Apply the tensor product, choosing the evaluation path from ``self._complete``.

        :param features_1: [..., in1] or [..., in1, in2] if features_2 is None
        :param features_2: [..., in2]
        :return: [..., out]
        """
        n_out = dim(self.Rs_out)
        n_in1 = dim(self.Rs_in1)
        n_in2 = dim(self.Rs_in2)

        use_full_matrix = self._complete == 'out' or features_2 is None
        if not use_full_matrix:
            # Partial paths: contract one input into a kernel first, then
            # apply that kernel to the other input.
            if self._complete == 'in1':
                kernel = self.left(features_1)  # [..., out, in2]
                return torch.einsum('...ij,...j->...i', kernel, features_2)
            if self._complete == 'in2':
                kernel = self.right(features_2)  # [..., out, in1]
                return torch.einsum('...ij,...j->...i', kernel, features_1)
            # Any other value of ``_complete`` yields no result.
            return None

        if features_2 is None:
            # features_1 already carries both input axes.
            pairs = features_1
        else:
            # Outer product over the last axes: [..., in1, in2]
            pairs = features_1.unsqueeze(-1) * features_2.unsqueeze(-2)

        lead_shape = pairs.shape[:-2]
        pairs = pairs.reshape(-1, n_in1, n_in2)  # [batch, in1, in2]

        mixing = get_sparse_buffer(
            self, "mixing_matrix")  # [out, in1 * in2]

        # Put the batch axis last so the matrix product contracts (in1, in2).
        columns = pairs.permute(1, 2, 0)  # [in1, in2, batch]
        columns = columns.reshape(n_in1 * n_in2, columns.shape[2])
        mixed = mixing @ columns  # [out, batch]
        return mixed.T.reshape(*lead_shape, n_out)
예제 #2
0
    def forward(self, features):
        """
        Multiply the last axis of ``features`` by the ``mixing_matrix`` buffer.

        :param features: tensor whose last axis is contracted; all leading
            axes are flattened into one batch axis for the product
        :return: tensor reshaped to [..., self.mul, -1], where ``...`` are the
            leading axes of ``features``
        """
        *lead_shape, width = features.shape
        flat = features.reshape(-1, width)  # [batch, width]

        mixing = get_sparse_buffer(self, 'mixing_matrix')
        mixed = mixing @ flat.T  # batch axis is last here
        return mixed.T.reshape(*lead_shape, self.mul, -1)
예제 #3
0
 def to_dense(self):
     """
     Materialise the ``mixing_matrix`` buffer as a dense rank-3 tensor.

     :return: tensor of shape [dim(Rs_out), dim(Rs_in1), dim(Rs_in2)]
     """
     dense = get_sparse_buffer(self, "mixing_matrix").to_dense()  # [out, in1 * in2]
     target_shape = (dim(self.Rs_out), dim(self.Rs_in1), dim(self.Rs_in2))
     return dense.reshape(*target_shape)
예제 #4
0
    def forward(self, features):
        """
        Mix the pairwise products of ``features`` with itself through the
        ``mixing_matrix`` buffer.

        :param features: [..., channels]
        :return: tensor with the leading axes of ``features`` followed by one
            output axis
        """
        *lead_shape, channels = features.shape
        flat = features.reshape(-1, channels)  # [batch, channels]

        mixing = get_sparse_buffer(self, "mixing_matrix")

        # pairs[i, j, z] = flat[z, i] * flat[z, j]
        pairs = torch.einsum('zi,zj->ijz', flat, flat)
        batch = pairs.shape[2]
        # Flatten (i, j) into one row index so a single matrix product applies.
        mixed = mixing @ pairs.reshape(-1, batch)
        return mixed.T.reshape(*lead_shape, -1)
예제 #5
0
    def forward(self, features_1, features_2):
        """
        Contract both inputs with a kernel built from the ``T`` buffer and
        ``self.kernel()``.

        :param features_1: tensor [..., dim(Rs_in1)]
        :param features_2: tensor whose leading axes equal those of ``features_1``
        :return:         tensor [..., channel]
        """
        *lead_1, width_1 = features_1.shape
        flat_1 = features_1.reshape(-1, width_1)
        assert width_1 == rs.dim(self.Rs_in1)
        *lead_2, width_2 = features_2.shape
        flat_2 = features_2.reshape(-1, width_2)
        assert lead_1 == lead_2

        basis = get_sparse_buffer(self, 'T')  # [out, in1 * in2]
        projected = (basis.t() @ self.kernel().T).T
        kernel_shape = (rs.dim(self.Rs_out), rs.dim(self.Rs_in1), rs.dim(self.Rs_in2))
        weights = projected.reshape(*kernel_shape)  # [out, in1, in2]

        # out[z, k] = sum_ij weights[k, i, j] * flat_1[z, i] * flat_2[z, j]
        out = torch.einsum('kij,zi,zj->zk', weights, flat_1, flat_2)
        return out.reshape(*lead_1, -1)
예제 #6
0
    def right(self, features_2):
        """
        Contract ``features_2`` with the in2 axis of the mixing matrix.

        :param features_2: [..., in2]
        :return: [..., out, in1]
        """
        n_out = dim(self.Rs_out)
        n_in1 = dim(self.Rs_in1)
        n_in2 = dim(self.Rs_in2)
        lead_shape = features_2.shape[:-1]
        flat = features_2.reshape(-1, n_in2)  # [batch, in2]

        mixing = get_sparse_buffer(self, "mixing_matrix")  # [out, in1 * in2]
        # Fold (out, in1) into the row index so in2 is the contracted axis.
        mixing = mixing.sparse_reshape(n_out * n_in1, n_in2)
        contracted = mixing @ flat.T  # [out * in1, batch]
        return contracted.T.reshape(*lead_shape, n_out, n_in1)
예제 #7
0
    def forward(self, features):
        """
        Contract the pairwise products of ``features`` with itself against a
        kernel built from the ``T`` buffer and ``self.kernel()``.

        :param features: [..., channels]
        :return: tensor with the leading axes of ``features`` followed by one
            output axis
        """
        *lead_shape, channels = features.shape
        flat = features.reshape(-1, channels)
        assert channels == rs.dim(self.Rs_in)

        if self.linear:
            # Prepend a column of ones so the pairwise products below also
            # contain each feature multiplied by 1.
            ones = flat.new_ones(flat.shape[0], 1)
            flat = torch.cat([ones, flat], dim=1)
            channels = channels + 1

        basis = get_sparse_buffer(self, 'T')  # [out, in1 * in2]
        projected = (basis.t() @ self.kernel().T).T
        weights = projected.reshape(rs.dim(self.Rs_out), channels, channels)  # [out, in1, in2]

        pairs = torch.einsum('zi,zj->zij', flat, flat)  # [batch, channels, channels]
        out = torch.einsum('kij,zij->zk', weights, pairs)
        return out.reshape(*lead_shape, -1)
예제 #8
0
    def left(self, features_1):
        """
        Contract ``features_1`` with the in1 axis of the mixing matrix.

        :param features_1: [..., in1]
        :return: [..., out, in2]
        """
        d_out = dim(self.Rs_out)
        d_in1 = dim(self.Rs_in1)
        d_in2 = dim(self.Rs_in2)
        size_1 = features_1.shape[:-1]
        features_1 = features_1.reshape(-1, d_in1)  # [batch, in1]

        mixing_matrix = get_sparse_buffer(self, "mixing_matrix")  # [out, in1 * in2]
        # Bring in1 to the column axis so one matrix product contracts it.
        mixing_matrix = mixing_matrix.sparse_reshape(d_out * d_in1, d_in2).t()  # [in2, out * in1]
        mixing_matrix = mixing_matrix.sparse_reshape(d_in2 * d_out, d_in1)  # [in2 * out, in1]
        output = mixing_matrix @ features_1.T  # [in2 * out, batch]
        output = output.reshape(d_in2, d_out, features_1.shape[0])  # [in2, out, batch]
        # Reorder axes to batch-first: j = in2, i = out, z = batch.
        output = torch.einsum('jiz->zij', output)  # [batch, out, in2]
        return output.reshape(*size_1, d_out, d_in2)
예제 #9
0
    def forward(self, features_1, features_2):
        """
        Apply the tensor product through the full ``mixing_matrix`` buffer.

        :param features_1: [..., in1]
        :param features_2: [..., in2]
        :return: [..., out]
        """
        n_out = dim(self.Rs_out)
        n_in1 = dim(self.Rs_in1)
        n_in2 = dim(self.Rs_in2)

        # Outer product over the last axes: [..., in1, in2]
        pairs = features_1.unsqueeze(-1) * features_2.unsqueeze(-2)

        lead_shape = pairs.shape[:-2]
        pairs = pairs.reshape(-1, n_in1, n_in2)  # [batch, in1, in2]

        mixing = get_sparse_buffer(self, "mixing_matrix")  # [out, in1 * in2]

        # Put the batch axis last so the matrix product contracts (in1, in2).
        columns = pairs.permute(1, 2, 0)  # [in1, in2, batch]
        columns = columns.reshape(n_in1 * n_in2, columns.shape[2])
        mixed = mixing @ columns  # [out, batch]
        return mixed.T.reshape(*lead_shape, n_out)