예제 #1
0
    def __init__(self,
                 num_freq,
                 in_dim,
                 out_dim,
                 edge_dim=None,
                 fourier_encode_dist=False,
                 num_fourier_features=4,
                 mid_dim=128):
        """Radial MLP producing per-frequency kernel weights.

        Maps a pairwise distance (optionally Fourier-encoded) concatenated
        with any edge features to a flat tensor of num_freq * in_dim * out_dim
        weights via a 3-layer ReLU network.
        """
        super().__init__()

        # record configuration for later reshaping of the network output
        self.num_freq = num_freq
        self.in_dim = in_dim
        self.mid_dim = mid_dim
        self.out_dim = out_dim
        self.edge_dim = default(edge_dim, 0)

        self.fourier_encode_dist = fourier_encode_dist
        # no fourier features are appended when distance encoding is disabled
        self.num_fourier_features = 0 if not fourier_encode_dist else num_fourier_features

        # raw distance scalar (1) + edge features + a sin/cos pair per fourier feature
        input_dim = 1 + self.edge_dim + 2 * self.num_fourier_features

        layers = [
            nn.Linear(input_dim, mid_dim), nn.LayerNorm(mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, mid_dim), nn.LayerNorm(mid_dim), nn.ReLU(),
            nn.Linear(mid_dim, num_freq * in_dim * out_dim),
        ]
        self.net = nn.Sequential(*layers)

        self.apply(self.init_)
    def __init__(
        self,
        num_freq,
        in_dim,
        out_dim,
        edge_dim = None,
        mid_dim = 128
    ):
        """Radial network mapping a distance (+ optional edge features) to
        num_freq * in_dim * out_dim kernel weights through a GELU MLP."""
        super().__init__()

        # stash configuration
        self.num_freq = num_freq
        self.in_dim = in_dim
        self.mid_dim = mid_dim
        self.out_dim = out_dim
        self.edge_dim = default(edge_dim, 0)

        # input is the scalar distance plus any edge features
        hidden = mid_dim
        mlp_layers = [
            nn.Linear(self.edge_dim + 1, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, num_freq * in_dim * out_dim),
        ]
        self.net = nn.Sequential(*mlp_layers)

        self.apply(self.init_)
예제 #3
0
def irr_repr(order, alpha, beta, gamma, dtype = None):
    """
    irreducible representation of SO3
    - compatible with compose and spherical_harmonics

    Euler angles are cast to torch tensors before the Wigner-D matrix
    for the given order is built; dtype falls back to torch's default.
    """
    dtype = default(dtype, torch.get_default_dtype())
    to_tensor = cast_torch_tensor(lambda t: t)
    alpha = to_tensor(alpha)
    beta = to_tensor(beta)
    gamma = to_tensor(gamma)
    return wigner_d_matrix(order, alpha, beta, gamma, dtype = dtype)
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 24,
        depth = 2,
        input_degrees = 1,
        num_degrees = 2,
        output_degrees = 1,
        valid_radius = 1e5,
        reduce_dim_out = False,
        num_tokens = None,
        num_edge_tokens = None,
        edge_dim = None,
        reversible = False,
        attend_self = True,
        use_null_kv = False,
        differentiable_coors = False,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        num_neighbors = float('inf'),
        attend_sparse_neighbors = False,
        num_adj_degrees = None,
        adj_dim = 0,
        max_sparse_neighbors = float('inf'),
        dim_in = None,
        dim_out = None,
        norm_out = False,
        num_conv_layers = 0,
        causal = False,
        splits = 4,
        global_feats_dim = None,
        linear_proj_keys = False,
        one_headed_key_values = False,
        tie_key_values = False,
        rotary_position = False,
        rotary_rel_dist = False
    ):
        """Build the SE3 transformer: token/edge/adjacency embeddings, an
        input ConvSE3, optional pre-conv layers, `depth` attention +
        feed-forward blocks, and an output ConvSE3 with optional norm and
        degree-reducing linear head.

        Keyword-only arguments configure feature dims per degree, attention
        shape, neighbor selection, and optional rotary/fourier encodings.
        """
        super().__init__()
        # dim_in defaults to the hidden dim; one dim per input degree
        dim_in = default(dim_in, dim)
        self.dim_in = cast_tuple(dim_in, input_degrees)
        self.dim = dim

        # token embedding

        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None

        # positional embedding

        self.rotary_rel_dist = rotary_rel_dist
        self.rotary_position = rotary_position

        self.rotary_pos_emb = None
        if rotary_position or rotary_rel_dist:
            # when both rotary schemes are active, each gets half of dim_head
            num_rotaries = int(rotary_position) + int(rotary_rel_dist)
            self.rotary_pos_emb = SinusoidalEmbeddings(dim_head // num_rotaries)

        # edges

        assert not (exists(num_edge_tokens) and not exists(edge_dim)), 'edge dimension (edge_dim) must be supplied if SE3 transformer is to have edge tokens'
        self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
        self.has_edges = exists(edge_dim) and edge_dim > 0

        self.input_degrees = input_degrees
        self.num_degrees = num_degrees
        self.output_degrees = output_degrees

        # whether to differentiate through basis, needed for alphafold2

        self.differentiable_coors = differentiable_coors

        # neighbors hyperparameters

        self.valid_radius = valid_radius
        self.num_neighbors = num_neighbors

        # sparse neighbors, derived from adjacency matrix or edges being passed in

        self.attend_sparse_neighbors = attend_sparse_neighbors
        self.max_sparse_neighbors = max_sparse_neighbors

        # adjacent neighbor derivation and embed

        # NOTE(review): the condition enforces num_adj_degrees >= 1 but the
        # message says 'greater than 1' — confirm which is intended
        assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure adjacent degrees is greater than 1'
        self.num_adj_degrees = num_adj_degrees
        self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None

        # total edge feature dim seen by conv/attention: raw edges + adjacency embedding
        edge_dim = (edge_dim if self.has_edges else 0) + (adj_dim if exists(self.adj_emb) else 0)

        # define fibers and dimensionality

        # NOTE(review): dim_in was already defaulted above; this second
        # default(...) is a no-op kept for safety
        dim_in = default(dim_in, dim)
        dim_out = default(dim_out, dim)

        fiber_in     = Fiber.create(input_degrees, dim_in)
        fiber_hidden = Fiber.create(num_degrees, dim)
        fiber_out    = Fiber.create(output_degrees, dim_out)

        # shared kwargs for every ConvSE3 built below
        conv_kwargs = dict(edge_dim = edge_dim, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)

        # causal

        assert not (causal and not attend_self), 'attending to self must be turned on if in autoregressive mode (for the first token)'
        self.causal = causal

        # main network

        self.conv_in  = ConvSE3(fiber_in, fiber_hidden, **conv_kwargs)

        # pre-convs

        self.convs = nn.ModuleList([])
        for _ in range(num_conv_layers):
            self.convs.append(nn.ModuleList([
                ConvSE3(fiber_hidden, fiber_hidden, **conv_kwargs),
                NormSE3(fiber_hidden)
            ]))

        # global features

        self.accept_global_feats = exists(global_feats_dim)
        assert not (reversible and self.accept_global_feats), 'reversibility and global features are not compatible'

        # trunk

        self.attend_self = attend_self

        attention_klass = OneHeadedKVAttentionSE3 if one_headed_key_values else AttentionSE3

        # depth x (attention block, feed-forward block)
        layers = nn.ModuleList([])
        for _ in range(depth):
            layers.append(nn.ModuleList([
                AttentionBlockSE3(fiber_hidden, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, fourier_encode_dist = fourier_encode_dist, rel_dist_num_fourier_features = rel_dist_num_fourier_features, use_null_kv = use_null_kv, splits = splits, global_feats_dim = global_feats_dim, linear_proj_keys = linear_proj_keys, attention_klass = attention_klass, tie_key_values = tie_key_values),
                FeedForwardBlockSE3(fiber_hidden)
            ]))

        # reversible execution trades memory for compute; sequential is the default
        execution_class = ReversibleSequence if reversible else SequentialSequence
        self.net = execution_class(layers)

        # out

        self.conv_out = ConvSE3(fiber_hidden, fiber_out, **conv_kwargs)

        # reversible nets require an output norm to stabilize the final features
        self.norm = NormSE3(fiber_out, nonlin = nn.Identity()) if norm_out or reversible else nn.Identity()

        # optional head collapsing each output degree down to a single feature
        self.linear_out = LinearSE3(
            fiber_out,
            Fiber.create(output_degrees, 1)
        ) if reduce_dim_out else None
예제 #5
0
import os
from math import pi
import torch
from torch import einsum
from einops import rearrange
from itertools import product
from contextlib import contextmanager

from se3_transformer_pytorch.irr_repr import irr_repr, spherical_harmonics
from se3_transformer_pytorch.utils import torch_default_dtype, cache_dir, exists, default, to_order
from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache

# constants

# cache directory for precomputed tensors; override with the CACHE_PATH env
# var, or disable caching entirely by setting CLEAR_CACHE
CACHE_PATH = default(os.getenv('CACHE_PATH'),
                     os.path.expanduser('~/.cache.equivariant_attention'))
CACHE_PATH = CACHE_PATH if not exists(os.environ.get('CLEAR_CACHE')) else None

# todo (figure out why this was hard coded in official repo)

RANDOM_ANGLES = [[4.41301023, 5.56684102, 4.59384642],
                 [4.93325116, 6.12697327, 4.14574096],
                 [0.53878964, 4.09050444, 5.36539036],
                 [2.16017393, 3.48835314, 5.55174441],
                 [2.52385107, 0.2908958, 3.90040975]]

# helpers


@contextmanager
def null_context():