Example #1
    def forward(self, inp, edge_info, rel_dist=None, basis=None):
        neighbor_indices, neighbor_masks, edges = edge_info
        rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')

        kernels = {}
        outputs = {}

        for (di, mi), (do, mo) in (self.fiber_in * self.fiber_out):
            etype = f'({di},{do})'
            kernel_fn = self.kernel_unary[etype]
            edge_features = torch.cat(
                (rel_dist, edges), dim=-1) if exists(edges) else rel_dist
            kernels[etype] = kernel_fn(edge_features, basis=basis)

        for degree_out in self.fiber_out.degrees:
            output = 0
            degree_out_key = str(degree_out)

            for degree_in, m_in in self.fiber_in:
                x = inp[str(degree_in)]
                x = batched_index_select(x, neighbor_indices, dim=1)
                x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)

                etype = f'({degree_in},{degree_out})'
                kernel = kernels[etype]
                output = output + einsum('... o i, ... i c -> ... o c', kernel,
                                         x)

            if self.pool:
                output = masked_mean(
                    output, neighbor_masks,
                    dim=2) if exists(neighbor_masks) else output.mean(dim=2)

            leading_shape = x.shape[:2] if self.pool else x.shape[:3]
            output = output.view(*leading_shape, -1, to_order(degree_out))

            outputs[degree_out_key] = output

        if self.self_interaction:
            self_interact_out = self.self_interact(inp)
            outputs = self.self_interact_sum(outputs, self_interact_out)

        return outputs
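
The forward pass above leans on a few small helpers. Below is a hedged sketch of two of them, inferred from how they are called here rather than copied from the library: to_order(degree) gives the dimension 2 * degree + 1 of a type-degree representation, and masked_mean averages over a dimension while ignoring masked-out neighbors.

import torch

def to_order(degree):
    # a type-`degree` irreducible representation has dimension 2 * degree + 1
    return 2 * degree + 1

def masked_mean(tensor, mask, dim=-1):
    # broadcast the mask over any trailing dims of `tensor`, zero the masked
    # entries, then divide by the count of surviving entries along `dim`
    diff_len = len(tensor.shape) - len(mask.shape)
    mask = mask[(..., *((None,) * diff_len))]
    tensor = tensor.masked_fill(~mask, 0.)
    total_el = mask.sum(dim=dim).float()
    mean = tensor.sum(dim=dim) / total_el.clamp(min=1.)
    return mean.masked_fill(total_el == 0, 0.)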
Example #2
    def forward(self,
                feats,
                coors,
                mask=None,
                adj_mat=None,
                edges=None,
                return_type=None,
                return_pooled=False):
        _mask = mask

        if self.output_degrees == 1:
            return_type = 0

        if exists(self.token_emb):
            feats = self.token_emb(feats)

        assert not (
            self.attend_sparse_neighbors and not exists(adj_mat)
        ), 'adjacency matrix (adj_mat) or edges (edges) must be passed in'
        assert not (
            exists(edges) and not exists(self.edge_emb)
        ), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

        if torch.is_tensor(feats):
            feats = {'0': feats[..., None]}

        b, n, d, *_, device = *feats['0'].shape, feats['0'].device

        assert d == self.dim, f'feature dimension {d} must be equal to dimension given at init {self.dim}'
        assert set(map(int, feats.keys())) == set(range(self.input_degrees)), \
            f'input must have {self.input_degrees} degrees'

        num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius

        assert self.attend_sparse_neighbors or neighbors > 0, 'you must either attend to sparsely bonded neighbors, or set number of locally attended neighbors to be greater than 0'

        # create N-degree adjacency matrix from 1st-degree connections

        if exists(self.num_adj_degrees):
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b=b)

            adj_indices = adj_mat.clone().long()

            for ind in range(self.num_adj_degrees - 1):
                degree = ind + 2

                next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
                next_degree_mask = (next_degree_adj_mat.float() -
                                    adj_mat.float()).bool()
                adj_indices.masked_fill_(next_degree_mask, degree)
                adj_mat = next_degree_adj_mat.clone()

        # se3 transformer by default cannot have a node attend to itself

        exclude_self_mask = rearrange(
            ~torch.eye(n, dtype=torch.bool, device=device), 'i j -> () i j')

        # calculate sparsely connected neighbors

        sparse_neighbor_mask = None
        num_sparse_neighbors = 0

        if self.attend_sparse_neighbors:
            assert exists(
                adj_mat
            ), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            if exists(adj_mat):
                if len(adj_mat.shape) == 2:
                    adj_mat = repeat(adj_mat, 'i j -> b i j', b=b)

            adj_mat = adj_mat.masked_select(exclude_self_mask).reshape(
                b, n, n - 1)

            adj_mat_values = adj_mat.float()
            adj_mat_max_neighbors = adj_mat_values.sum(dim=-1).max().item()

            if max_sparse_neighbors < adj_mat_max_neighbors:
                noise = torch.empty_like(adj_mat_values).uniform_(-0.01, 0.01)
                adj_mat_values += noise

            num_sparse_neighbors = int(
                min(max_sparse_neighbors, adj_mat_max_neighbors))
            values, indices = adj_mat_values.topk(num_sparse_neighbors, dim=-1)
            sparse_neighbor_mask = torch.zeros_like(adj_mat_values).scatter_(
                -1, indices, values)
            sparse_neighbor_mask = sparse_neighbor_mask > 0.5

        # exclude edge of token to itself

        indices = repeat(torch.arange(n, device=device),
                         'i -> b i j',
                         b=b,
                         j=n)
        rel_pos = rearrange(coors, 'b n d -> b n () d') - rearrange(
            coors, 'b n d -> b () n d')

        indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
        rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(
            b, n, n - 1, 3)

        if exists(mask):
            mask = rearrange(mask, 'b i -> b i ()') * rearrange(
                mask, 'b j -> b () j')
            mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)

        if exists(edges):
            edges = self.edge_emb(edges)
            edges = edges.masked_select(exclude_self_mask[..., None]).reshape(
                b, n, n - 1, -1)

        if exists(self.adj_emb):
            adj_indices = adj_indices.masked_select(exclude_self_mask).reshape(
                b, n, n - 1)
            adj_emb = self.adj_emb(adj_indices)
            edges = torch.cat(
                (edges, adj_emb), dim=-1) if exists(edges) else adj_emb

        rel_dist = rel_pos.norm(dim=-1)

        # use sparse neighbor mask to give priority to bonded neighbors

        modified_rel_dist = rel_dist.clone()
        if exists(sparse_neighbor_mask):
            modified_rel_dist.masked_fill_(sparse_neighbor_mask, 0.)

        # if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix

        if neighbors == 0:
            valid_radius = 0

        # get neighbors and neighbor mask, excluding self

        neighbors = int(min(neighbors, n - 1))
        total_neighbors = int(neighbors + num_sparse_neighbors)
        assert total_neighbors > 0, 'you must be fetching at least 1 neighbor'

        total_neighbors = int(
            min(total_neighbors, n - 1)
        )  # make sure total neighbors does not exceed the length of the sequence itself

        dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors,
                                                              dim=-1,
                                                              largest=False)
        neighbor_mask = dist_values <= valid_radius

        neighbor_rel_dist = batched_index_select(rel_dist,
                                                 nearest_indices,
                                                 dim=2)
        neighbor_rel_pos = batched_index_select(rel_pos,
                                                nearest_indices,
                                                dim=2)
        neighbor_indices = batched_index_select(indices,
                                                nearest_indices,
                                                dim=2)

        if exists(mask):
            neighbor_mask = neighbor_mask & batched_index_select(
                mask, nearest_indices, dim=2)

        if exists(edges):
            edges = batched_index_select(edges, nearest_indices, dim=2)

        # calculate basis

        basis = get_basis(neighbor_rel_pos,
                          num_degrees - 1,
                          differentiable=self.differentiable_coors)

        # main logic

        edge_info = (neighbor_indices, neighbor_mask, edges)
        x = feats

        # project in

        x = self.conv_in(x, edge_info, rel_dist=neighbor_rel_dist, basis=basis)

        # transformer layers

        x = self.net(x,
                     edge_info=edge_info,
                     rel_dist=neighbor_rel_dist,
                     basis=basis)

        # project out

        x = self.conv_out(x,
                          edge_info,
                          rel_dist=neighbor_rel_dist,
                          basis=basis)

        # norm

        x = self.norm(x)

        # reduce dim if specified

        if exists(self.linear_out):
            x = self.linear_out(x)
            x = map_values(lambda t: t.squeeze(dim=2), x)

        if return_pooled:
            mask_fn = (lambda t: masked_mean(t, _mask, dim=1)
                       ) if exists(_mask) else (lambda t: t.mean(dim=1))
            x = map_values(mask_fn, x)

        if '0' in x:
            x['0'] = x['0'].squeeze(dim=-1)

        if exists(return_type):
            return x[str(return_type)]

        return x
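
For context, here is a minimal end-to-end call of this forward, following the library's README; the sizes are illustrative, not prescriptive.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(dim=32,
                       heads=2,
                       depth=1,
                       num_degrees=2,
                       num_neighbors=4)

feats = torch.randn(1, 16, 32)   # type-0 features, (batch, seq, dim)
coors = torch.randn(1, 16, 3)    # coordinates, (batch, seq, 3)
mask = torch.ones(1, 16).bool()  # padding mask

out = model(feats, coors, mask=mask, return_type=0)  # (1, 16, 32)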
Example #3
    def __init__(self,
                 *,
                 dim,
                 heads=8,
                 dim_head=64,
                 depth=2,
                 input_degrees=1,
                 num_degrees=2,
                 output_degrees=1,
                 valid_radius=1e5,
                 reduce_dim_out=False,
                 num_tokens=None,
                 num_edge_tokens=None,
                 edge_dim=None,
                 reversible=False,
                 attend_self=True,
                 use_null_kv=False,
                 differentiable_coors=False,
                 fourier_encode_dist=False,
                 num_neighbors=float('inf'),
                 attend_sparse_neighbors=False,
                 num_adj_degrees=None,
                 adj_dim=0,
                 max_sparse_neighbors=float('inf')):
        super().__init__()
        self.dim = dim

        self.token_emb = None
        self.token_emb = nn.Embedding(num_tokens,
                                      dim) if exists(num_tokens) else None

        assert not (
            exists(num_edge_tokens) and not exists(edge_dim)
        ), 'edge dimension (edge_dim) must be supplied if SE3 transformer is to have edge tokens'
        self.edge_emb = nn.Embedding(
            num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None

        self.input_degrees = input_degrees
        self.num_degrees = num_degrees
        self.output_degrees = output_degrees

        # whether to differentiate through basis, needed for alphafold2

        self.differentiable_coors = differentiable_coors

        # neighbors hyperparameters

        self.valid_radius = valid_radius
        self.num_neighbors = num_neighbors

        # sparse neighbors, derived from adjacency matrix or edges being passed in

        self.attend_sparse_neighbors = attend_sparse_neighbors
        self.max_sparse_neighbors = max_sparse_neighbors

        # adjacent neighbor derivation and embed

        assert not (exists(num_adj_degrees) and num_adj_degrees < 1
                    ), 'make sure adjacent degrees is at least 1'
        self.num_adj_degrees = num_adj_degrees
        self.adj_emb = nn.Embedding(
            num_adj_degrees +
            1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None

        edge_dim = (edge_dim if exists(self.edge_emb) else
                    0) + (adj_dim if exists(self.adj_emb) else 0)

        # main network

        fiber_in = Fiber.create(input_degrees, dim)
        fiber_hidden = Fiber.create(num_degrees, dim)
        fiber_out = Fiber.create(output_degrees, dim)

        self.conv_in = ConvSE3(fiber_in,
                               fiber_hidden,
                               edge_dim=edge_dim,
                               fourier_encode_dist=fourier_encode_dist)

        layers = nn.ModuleList([])
        for _ in range(depth):
            layers.append(
                nn.ModuleList([
                    AttentionBlockSE3(fiber_hidden,
                                      heads=heads,
                                      dim_head=dim_head,
                                      attend_self=attend_self,
                                      edge_dim=edge_dim,
                                      fourier_encode_dist=fourier_encode_dist,
                                      use_null_kv=use_null_kv),
                    FeedForwardBlockSE3(fiber_hidden)
                ]))

        execution_class = ReversibleSequence if reversible else SequentialSequence
        self.net = execution_class(layers)

        self.conv_out = ConvSE3(fiber_hidden,
                                fiber_out,
                                edge_dim=edge_dim,
                                fourier_encode_dist=fourier_encode_dist)

        self.norm = NormSE3(fiber_out)

        self.linear_out = LinearSE3(fiber_out, Fiber.create(
            output_degrees, 1)) if reduce_dim_out else None
Exemplo n.º 4
0
    def forward(self, features, edge_info, rel_dist, basis):
        h, attend_self = self.heads, self.attend_self
        device, dtype = get_tensor_device_and_dtype(features)
        neighbor_indices, neighbor_mask, edges = edge_info

        max_neg_value = -torch.finfo(dtype).max

        if exists(neighbor_mask):
            neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')

        neighbor_indices = rearrange(neighbor_indices, 'b i j -> b () i j')

        queries = self.to_q(features)
        keys, values = self.to_k(features, edge_info, rel_dist,
                                 basis), self.to_v(features, edge_info,
                                                   rel_dist, basis)

        if attend_self:
            self_keys, self_values = self.to_self_k(features), self.to_self_v(
                features)

        outputs = {}
        for degree in features.keys():
            q, k, v = map(lambda t: t[degree], (queries, keys, values))

            q = rearrange(q, 'b i (h d) m -> b h i d m', h=h)
            k, v = map(
                lambda t: rearrange(t, 'b i j (h d) m -> b h i j d m', h=h),
                (k, v))

            if self.use_null_kv:
                null_k, null_v = map(lambda t: t[degree],
                                     (self.null_keys, self.null_values))
                null_k, null_v = map(
                    lambda t: repeat(
                        t, 'h d m -> b h i () d m', b=q.shape[0], i=q.shape[
                            2]), (null_k, null_v))
                k = torch.cat((null_k, k), dim=3)
                v = torch.cat((null_v, v), dim=3)

            if attend_self:
                self_k, self_v = map(lambda t: t[degree],
                                     (self_keys, self_values))
                self_k, self_v = map(
                    lambda t: rearrange(t, 'b n (h d) m -> b h n () d m', h=h),
                    (self_k, self_v))
                k = torch.cat((self_k, k), dim=3)
                v = torch.cat((self_v, v), dim=3)

            sim = einsum('b h i d m, b h i j d m -> b h i j', q,
                         k) * self.scale

            if exists(neighbor_mask):
                num_left_pad = int(attend_self) + int(self.use_null_kv)
                mask = F.pad(neighbor_mask, (num_left_pad, 0), value=True)
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            out = einsum('b h i j, b h i j d m -> b h i d m', attn, v)
            outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')

        return self.to_out(outputs)
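
The two einsums above contract the representation axes (d for head channels, m for the 2l + 1 components of each degree) into per-neighbor attention logits, then reweight the values. A toy shape check, with arbitrary sizes:

import torch
from torch import einsum

b, h, i, j, d, m = 2, 8, 16, 12, 64, 3
q = torch.randn(b, h, i, d, m)     # one query per node
k = torch.randn(b, h, i, j, d, m)  # one key per (node, neighbor) pair
v = torch.randn(b, h, i, j, d, m)

sim = einsum('b h i d m, b h i j d m -> b h i j', q, k)  # logits per neighbor
attn = sim.softmax(dim=-1)
out = einsum('b h i j, b h i j d m -> b h i d m', attn, v)
assert out.shape == (b, h, i, d, m)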
Example #5
    def forward(
        self,
        feats,
        coors,
        mask = None,
        adj_mat = None,
        edges = None,
        return_type = None,
        return_pooled = False,
        neighbor_mask = None,
        global_feats = None
    ):
        assert not (self.accept_global_feats ^ exists(global_feats)), 'you cannot pass in global features unless you init the class correctly'

        _mask = mask

        if self.output_degrees == 1:
            return_type = 0

        if exists(self.token_emb):
            feats = self.token_emb(feats)

        assert not (self.attend_sparse_neighbors and not exists(adj_mat)), 'adjacency matrix (adj_mat) or edges (edges) must be passed in'
        assert not (self.has_edges and not exists(edges)), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

        if torch.is_tensor(feats):
            feats = {'0': feats[..., None]}

        if torch.is_tensor(global_feats):
            global_feats = {'0': global_feats[..., None]}

        b, n, d, *_, device = *feats['0'].shape, feats['0'].device

        assert d == self.dim_in[0], f'feature dimension {d} must be equal to dimension given at init {self.dim_in[0]}'
        assert set(map(int, feats.keys())) == set(range(self.input_degrees)), f'input must have {self.input_degrees} degrees'

        num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius

        assert self.attend_sparse_neighbors or neighbors > 0, 'you must either attend to sparsely bonded neighbors, or set number of locally attended neighbors to be greater than 0'

        # se3 transformer by default cannot have a node attend to itself

        exclude_self_mask = rearrange(~torch.eye(n, dtype = torch.bool, device = device), 'i j -> () i j')
        remove_self = lambda t: t.masked_select(exclude_self_mask).reshape(b, n, n - 1)
        get_max_value = lambda t: torch.finfo(t.dtype).max

        # create N-degree adjacency matrix from 1st-degree connections

        if exists(self.num_adj_degrees):
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

            adj_indices = adj_mat.clone().long()

            for ind in range(self.num_adj_degrees - 1):
                degree = ind + 2

                next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
                next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
                adj_indices.masked_fill_(next_degree_mask, degree)
                adj_mat = next_degree_adj_mat.clone()

            adj_indices = adj_indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)

        # calculate sparsely connected neighbors

        sparse_neighbor_mask = None
        num_sparse_neighbors = 0

        if self.attend_sparse_neighbors:
            assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            if exists(adj_mat):
                if len(adj_mat.shape) == 2:
                    adj_mat = repeat(adj_mat, 'i j -> b i j', b = b)

            adj_mat = remove_self(adj_mat)

            adj_mat_values = adj_mat.float()
            adj_mat_max_neighbors = adj_mat_values.sum(dim = -1).max().item()

            if max_sparse_neighbors < adj_mat_max_neighbors:
                noise = torch.empty_like(adj_mat_values).uniform_(-0.01, 0.01)
                adj_mat_values += noise

            num_sparse_neighbors = int(min(max_sparse_neighbors, adj_mat_max_neighbors))
            values, indices = adj_mat_values.topk(num_sparse_neighbors, dim = -1)
            sparse_neighbor_mask = torch.zeros_like(adj_mat_values).scatter_(-1, indices, values)
            sparse_neighbor_mask = sparse_neighbor_mask > 0.5

        # exclude edge of token to itself

        indices = repeat(torch.arange(n, device = device), 'i -> b i j', b = b, j = n)
        rel_pos  = rearrange(coors, 'b n d -> b n () d') - rearrange(coors, 'b n d -> b () n d')

        indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
        rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, 3)

        if exists(mask):
            mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
            mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)

        if exists(edges):
            if exists(self.edge_emb):
                edges = self.edge_emb(edges)

            edges = edges.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, -1)

        if exists(self.adj_emb):
            adj_emb = self.adj_emb(adj_indices)
            edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb

        rel_dist = rel_pos.norm(dim = -1)

        # rel_dist gets modified using adjacency or neighbor mask

        modified_rel_dist = rel_dist.clone()
        max_value = get_max_value(modified_rel_dist) # for masking out nodes from being considered as neighbors

        # neighbors

        if exists(neighbor_mask):
            neighbor_mask = remove_self(neighbor_mask)

            max_neighbors = neighbor_mask.sum(dim = -1).max().item()
            if max_neighbors > neighbors:
                print(f'neighbor_mask shows maximum number of neighbors as {max_neighbors} but specified number of neighbors is {neighbors}')

            modified_rel_dist.masked_fill_(~neighbor_mask, max_value)

        # use sparse neighbor mask to give priority to bonded neighbors

        if exists(sparse_neighbor_mask):
            modified_rel_dist.masked_fill_(sparse_neighbor_mask, 0.)

        # mask out future nodes to high distance if causal turned on

        if self.causal:
            causal_mask = torch.ones(n, n - 1, device = device).triu().bool()
            modified_rel_dist.masked_fill_(causal_mask[None, ...], max_value)

        # if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix

        if neighbors == 0:
            valid_radius = 0

        # get neighbors and neighbor mask, excluding self

        neighbors = int(min(neighbors, n - 1))
        total_neighbors = int(neighbors + num_sparse_neighbors)
        assert total_neighbors > 0, 'you must be fetching at least 1 neighbor'

        total_neighbors = int(min(total_neighbors, n - 1)) # make sure total neighbors does not exceed the length of the sequence itself

        dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors, dim = -1, largest = False)
        neighbor_mask = dist_values <= valid_radius

        neighbor_rel_dist = batched_index_select(rel_dist, nearest_indices, dim = 2)
        neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim = 2)
        neighbor_indices = batched_index_select(indices, nearest_indices, dim = 2)

        if exists(mask):
            neighbor_mask = neighbor_mask & batched_index_select(mask, nearest_indices, dim = 2)

        if exists(edges):
            edges = batched_index_select(edges, nearest_indices, dim = 2)

        # calculate rotary pos emb

        rotary_pos_emb = None
        rotary_query_pos_emb = None
        rotary_key_pos_emb = None

        if self.rotary_position:
            seq = torch.arange(n, device = device)
            seq_pos_emb = self.rotary_pos_emb(seq)
            self_indices = torch.arange(neighbor_indices.shape[1], device = device)
            self_indices = repeat(self_indices, 'i -> b i ()', b = b)
            neighbor_indices_with_self = torch.cat((self_indices, neighbor_indices), dim = 2)
            pos_emb = batched_index_select(seq_pos_emb, neighbor_indices_with_self, dim = 0)

            rotary_key_pos_emb = pos_emb
            rotary_query_pos_emb = repeat(seq_pos_emb, 'n d -> b n d', b = b)

        if self.rotary_rel_dist:
            neighbor_rel_dist_with_self = F.pad(neighbor_rel_dist, (1, 0), value = 0) * 1e2
            rel_dist_pos_emb = self.rotary_pos_emb(neighbor_rel_dist_with_self)
            rotary_key_pos_emb = safe_cat(rotary_key_pos_emb, rel_dist_pos_emb, dim = -1)

            query_dist = torch.zeros(n, device = device)
            query_pos_emb = self.rotary_pos_emb(query_dist)
            query_pos_emb = repeat(query_pos_emb, 'n d -> b n d', b = b)

            rotary_query_pos_emb = safe_cat(rotary_query_pos_emb, query_pos_emb, dim = -1)

        if exists(rotary_query_pos_emb) and exists(rotary_key_pos_emb):
            rotary_pos_emb = (rotary_query_pos_emb, rotary_key_pos_emb)

        # calculate basis

        basis = get_basis(neighbor_rel_pos, num_degrees - 1, differentiable = self.differentiable_coors)

        # main logic

        edge_info = (neighbor_indices, neighbor_mask, edges)
        x = feats

        # project in

        x = self.conv_in(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)

        # preconvolution layers

        for conv, nonlin in self.convs:
            x = nonlin(x)
            x = conv(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)

        # transformer layers

        x = self.net(x, edge_info = edge_info, rel_dist = neighbor_rel_dist, basis = basis, global_feats = global_feats, pos_emb = rotary_pos_emb)

        # project out

        x = self.conv_out(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)

        # norm

        x = self.norm(x)

        # reduce dim if specified

        if exists(self.linear_out):
            x = self.linear_out(x)
            x = map_values(lambda t: t.squeeze(dim = 2), x)

        if return_pooled:
            mask_fn = (lambda t: masked_mean(t, _mask, dim = 1)) if exists(_mask) else (lambda t: t.mean(dim = 1))
            x = map_values(mask_fn, x)

        if '0' in x:
            x['0'] = x['0'].squeeze(dim = -1)

        if exists(return_type):
            return x[str(return_type)]

        return x
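
Here is a hedged usage sketch of the sparse-neighbor path above, following the library's README: a simple chain molecule where node i is bonded to node i + 1, attending only to bonded (and derived second-degree) neighbors.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 32,
    depth = 1,
    heads = 2,
    num_degrees = 2,
    num_neighbors = 0,               # fetch no distance-based neighbors
    attend_sparse_neighbors = True,  # attend to bonded neighbors instead
    num_adj_degrees = 2,             # also derive 2nd-degree connections
    adj_dim = 4                      # embed adjacency degree into the edges
)

feats = torch.randn(1, 32, 32)
coors = torch.randn(1, 32, 3)
mask = torch.ones(1, 32).bool()

seq = torch.arange(32)
adj_mat = (seq[:, None] - seq[None, :]).abs() == 1  # (n, n) chain adjacency

out = model(feats, coors, mask = mask, adj_mat = adj_mat, return_type = 0)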
Example #6
    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None):
        h, attend_self = self.heads, self.attend_self
        device, dtype = get_tensor_device_and_dtype(features)
        neighbor_indices, neighbor_mask, edges = edge_info

        max_neg_value = -torch.finfo(dtype).max

        if exists(neighbor_mask):
            neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')

        queries = self.to_q(features)
        values  = self.to_v(features, edge_info, rel_dist, basis)

        if self.linear_proj_keys:
            keys = self.to_k(features)
            keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys)
        elif not exists(self.to_k):
            keys = values
        else:
            keys = self.to_k(features, edge_info, rel_dist, basis)

        if attend_self:
            self_keys, self_values = self.to_self_k(features), self.to_self_v(features)

        if exists(global_feats):
            global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats)

        outputs = {}
        for degree in features.keys():
            q, k, v = map(lambda t: t[degree], (queries, keys, values))

            q = rearrange(q, 'b i (h d) m -> b h i d m', h = h)

            if attend_self:
                self_k, self_v = map(lambda t: t[degree], (self_keys, self_values))
                self_k, self_v = map(lambda t: rearrange(t, 'b n d m -> b n () d m'), (self_k, self_v))
                k = torch.cat((self_k, k), dim = 2)
                v = torch.cat((self_v, v), dim = 2)

            if exists(pos_emb) and degree == '0':
                query_pos_emb, key_pos_emb = pos_emb
                query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
                key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b i j d ()')
                q = apply_rotary_pos_emb(q, query_pos_emb)
                k = apply_rotary_pos_emb(k, key_pos_emb)
                v = apply_rotary_pos_emb(v, key_pos_emb)

            if self.use_null_kv:
                null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values))
                null_k, null_v = map(lambda t: repeat(t, 'd m -> b i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v))
                k = torch.cat((null_k, k), dim = 2)
                v = torch.cat((null_v, v), dim = 2)

            if exists(global_feats) and degree == '0':
                global_k, global_v = map(lambda t: t[degree], (global_keys, global_values))
                global_k, global_v = map(lambda t: repeat(t, 'b j d m -> b i j d m', i = k.shape[1]), (global_k, global_v))
                k = torch.cat((global_k, k), dim = 2)
                v = torch.cat((global_v, v), dim = 2)

            sim = einsum('b h i d m, b i j d m -> b h i j', q, k) * self.scale

            if exists(neighbor_mask):
                num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
                mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim = -1)
            out = einsum('b h i j, b i j d m -> b h i d m', attn, v)
            outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')

        return self.to_out(outputs)
Example #7
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 24,
        depth = 2,
        input_degrees = 1,
        num_degrees = 2,
        output_degrees = 1,
        valid_radius = 1e5,
        reduce_dim_out = False,
        num_tokens = None,
        num_edge_tokens = None,
        edge_dim = None,
        reversible = False,
        attend_self = True,
        use_null_kv = False,
        differentiable_coors = False,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        num_neighbors = float('inf'),
        attend_sparse_neighbors = False,
        num_adj_degrees = None,
        adj_dim = 0,
        max_sparse_neighbors = float('inf'),
        dim_in = None,
        dim_out = None,
        norm_out = False,
        num_conv_layers = 0,
        causal = False,
        splits = 4,
        global_feats_dim = None,
        linear_proj_keys = False,
        one_headed_key_values = False,
        tie_key_values = False,
        rotary_position = False,
        rotary_rel_dist = False
    ):
        super().__init__()
        dim_in = default(dim_in, dim)
        self.dim_in = cast_tuple(dim_in, input_degrees)
        self.dim = dim

        # token embedding

        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None

        # positional embedding

        self.rotary_rel_dist = rotary_rel_dist
        self.rotary_position = rotary_position

        self.rotary_pos_emb = None
        if rotary_position or rotary_rel_dist:
            num_rotaries = int(rotary_position) + int(rotary_rel_dist)
            self.rotary_pos_emb = SinusoidalEmbeddings(dim_head // num_rotaries)

        # edges

        assert not (exists(num_edge_tokens) and not exists(edge_dim)), 'edge dimension (edge_dim) must be supplied if SE3 transformer is to have edge tokens'
        self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
        self.has_edges = exists(edge_dim) and edge_dim > 0

        self.input_degrees = input_degrees
        self.num_degrees = num_degrees
        self.output_degrees = output_degrees

        # whether to differentiate through basis, needed for alphafold2

        self.differentiable_coors = differentiable_coors

        # neighbors hyperparameters

        self.valid_radius = valid_radius
        self.num_neighbors = num_neighbors

        # sparse neighbors, derived from adjacency matrix or edges being passed in

        self.attend_sparse_neighbors = attend_sparse_neighbors
        self.max_sparse_neighbors = max_sparse_neighbors

        # adjacent neighbor derivation and embed

        assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure adjacent degrees is at least 1'
        self.num_adj_degrees = num_adj_degrees
        self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None

        edge_dim = (edge_dim if self.has_edges else 0) + (adj_dim if exists(self.adj_emb) else 0)

        # define fibers and dimensionality

        dim_in = default(dim_in, dim)
        dim_out = default(dim_out, dim)

        fiber_in     = Fiber.create(input_degrees, dim_in)
        fiber_hidden = Fiber.create(num_degrees, dim)
        fiber_out    = Fiber.create(output_degrees, dim_out)

        conv_kwargs = dict(edge_dim = edge_dim, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)

        # causal

        assert not (causal and not attend_self), 'attending to self must be turned on if in autoregressive mode (for the first token)'
        self.causal = causal

        # main network

        self.conv_in  = ConvSE3(fiber_in, fiber_hidden, **conv_kwargs)

        # pre-convs

        self.convs = nn.ModuleList([])
        for _ in range(num_conv_layers):
            self.convs.append(nn.ModuleList([
                ConvSE3(fiber_hidden, fiber_hidden, **conv_kwargs),
                NormSE3(fiber_hidden)
            ]))

        # global features

        self.accept_global_feats = exists(global_feats_dim)
        assert not (reversible and self.accept_global_feats), 'reversibility and global features are not compatible'

        # trunk

        self.attend_self = attend_self

        attention_klass = OneHeadedKVAttentionSE3 if one_headed_key_values else AttentionSE3

        layers = nn.ModuleList([])
        for _ in range(depth):
            layers.append(nn.ModuleList([
                AttentionBlockSE3(fiber_hidden, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, fourier_encode_dist = fourier_encode_dist, rel_dist_num_fourier_features = rel_dist_num_fourier_features, use_null_kv = use_null_kv, splits = splits, global_feats_dim = global_feats_dim, linear_proj_keys = linear_proj_keys, attention_klass = attention_klass, tie_key_values = tie_key_values),
                FeedForwardBlockSE3(fiber_hidden)
            ]))

        execution_class = ReversibleSequence if reversible else SequentialSequence
        self.net = execution_class(layers)

        # out

        self.conv_out = ConvSE3(fiber_hidden, fiber_out, **conv_kwargs)

        self.norm = NormSE3(fiber_out, nonlin = nn.Identity()) if norm_out or reversible else nn.Identity()

        self.linear_out = LinearSE3(
            fiber_out,
            Fiber.create(output_degrees, 1)
        ) if reduce_dim_out else None
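
A hedged construction sketch exercising the newer flags of this __init__; the argument names come from the signature above, the values are illustrative.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 32,
    depth = 1,
    dim_head = 24,
    num_degrees = 2,
    num_neighbors = 8,
    global_feats_dim = 16,    # enables the to_global_k / to_global_v paths
    linear_proj_keys = True,  # keys via LinearSE3 rather than ConvSE3
    rotary_position = True,   # rotary embedding over sequence position
    rotary_rel_dist = True    # rotary embedding over relative distance
)

feats = torch.randn(1, 16, 32)
coors = torch.randn(1, 16, 3)
mask = torch.ones(1, 16).bool()
global_feats = torch.randn(1, 2, 16)  # (batch, num global tokens, dim)

out = model(feats, coors, mask = mask, global_feats = global_feats, return_type = 0)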
Example #8
    def forward(
        self,
        inp,
        edge_info,
        rel_dist = None,
        basis = None
    ):
        splits = self.splits
        neighbor_indices, neighbor_masks, edges = edge_info
        rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')

        kernels = {}
        outputs = {}

        if self.fourier_encode_dist:
            rel_dist = fourier_encode(rel_dist[..., None], num_encodings = self.num_fourier_features)

        # split basis

        basis_keys = basis.keys()
        split_basis_values = list(zip(*list(map(lambda t: fast_split(t, splits, dim = 1), basis.values()))))
        split_basis = list(map(lambda v: dict(zip(basis_keys, v)), split_basis_values))

        # go through every permutation of input degree type to output degree type

        for degree_out in self.fiber_out.degrees:
            output = 0
            degree_out_key = str(degree_out)

            for degree_in, m_in in self.fiber_in:
                etype = f'({degree_in},{degree_out})'

                x = inp[str(degree_in)]

                x = batched_index_select(x, neighbor_indices, dim = 1)
                x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)

                kernel_fn = self.kernel_unary[etype]
                edge_features = torch.cat((rel_dist, edges), dim = -1) if exists(edges) else rel_dist

                output_chunk = None
                split_x = fast_split(x, splits, dim = 1)
                split_edge_features = fast_split(edge_features, splits, dim = 1)

                # process input, edges, and basis in chunks along the sequence dimension

                for x_chunk, edge_chunk, basis_chunk in zip(split_x, split_edge_features, split_basis):
                    kernel = kernel_fn(edge_chunk, basis = basis_chunk)
                    chunk = einsum('... o i, ... i c -> ... o c', kernel, x_chunk)
                    output_chunk = safe_cat(output_chunk, chunk, dim = 1)

                output = output + output_chunk

            if self.pool:
                output = masked_mean(output, neighbor_masks, dim = 2) if exists(neighbor_masks) else output.mean(dim = 2)

            leading_shape = x.shape[:2] if self.pool else x.shape[:3]
            output = output.view(*leading_shape, -1, to_order(degree_out))

            outputs[degree_out_key] = output

        if self.self_interaction:
            self_interact_out = self.self_interact(inp)
            outputs = self.self_interact_sum(outputs, self_interact_out)

        return outputs
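
The chunked loop above relies on fast_split. A hedged sketch of its behavior, as consumed here: yield splits chunks of a tensor along dim, spreading any remainder over the leading chunks so nothing is dropped.

import torch

def fast_split(arr, splits, dim = 0):
    axis_len = arr.shape[dim]
    splits = min(axis_len, max(splits, 1))
    chunk_size = axis_len // splits
    remainder = axis_len - chunk_size * splits
    start = 0
    for _ in range(splits):
        extra = 1 if remainder > 0 else 0
        remainder -= extra
        yield torch.narrow(arr, dim, start, chunk_size + extra)
        start += chunk_size + extra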
Example #9
    def __init__(
        self,
        fiber,
        dim_head = 64,
        heads = 8,
        attend_self = False,
        edge_dim = None,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        use_null_kv = False,
        splits = 4,
        global_feats_dim = None,
        linear_proj_keys = False,
        tie_key_values = False
    ):
        super().__init__()
        hidden_dim = dim_head * heads
        hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber)))
        kv_hidden_fiber = Fiber(list(map(lambda t: (t[0], dim_head), fiber)))
        project_out = not (heads == 1 and len(fiber.dims) == 1 and dim_head == fiber.dims[0])

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.linear_proj_keys = linear_proj_keys # whether to linearly project features for keys, rather than convolve with basis

        self.to_q = LinearSE3(fiber, hidden_fiber)
        self.to_v = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)

        assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time'

        if linear_proj_keys:
            self.to_k = LinearSE3(fiber, kv_hidden_fiber)
        elif not tie_key_values:
            self.to_k = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
        else:
            self.to_k = None

        self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity()

        self.use_null_kv = use_null_kv
        if use_null_kv:
            self.null_keys = nn.ParameterDict()
            self.null_values = nn.ParameterDict()

            for degree in fiber.degrees:
                m = to_order(degree)
                degree_key = str(degree)
                self.null_keys[degree_key] = nn.Parameter(torch.zeros(dim_head, m))
                self.null_values[degree_key] = nn.Parameter(torch.zeros(dim_head, m))

        self.attend_self = attend_self
        if attend_self:
            self.to_self_k = LinearSE3(fiber, kv_hidden_fiber)
            self.to_self_v = LinearSE3(fiber, kv_hidden_fiber)

        self.accept_global_feats = exists(global_feats_dim)
        if self.accept_global_feats:
            global_input_fiber = Fiber.create(1, global_feats_dim)
            global_output_fiber = Fiber.create(1, kv_hidden_fiber[0])
            self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber)
            self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber)
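
A toy check (arbitrary sizes) of how the per-degree null key defined above, of shape (dim_head, m), is broadcast to every (batch, node) pair and prepended along the neighbor axis, matching the repeat pattern 'd m -> b i () d m' used in the corresponding forward.

import torch
from einops import repeat

dim_head, m, b, i, j = 64, 3, 2, 16, 12
null_k = torch.zeros(dim_head, m)
null_k = repeat(null_k, 'd m -> b i () d m', b = b, i = i)
k = torch.randn(b, i, j, dim_head, m)
k = torch.cat((null_k, k), dim = 2)  # neighbor axis grows by one
assert k.shape == (b, i, j + 1, dim_head, m)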
Example #10
import os
from math import pi
import torch
from torch import einsum
from einops import rearrange
from itertools import product
from contextlib import contextmanager

from se3_transformer_pytorch.irr_repr import irr_repr, spherical_harmonics
from se3_transformer_pytorch.utils import torch_default_dtype, cache_dir, exists, to_order
from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache

# constants

CACHE_PATH = os.path.expanduser(
    '~/.cache.equivariant_attention') if not exists(
        os.environ.get('CLEAR_CACHE')) else None

# todo (figure out why this was hard coded in the official repo)

RANDOM_ANGLES = [[4.41301023, 5.56684102, 4.59384642],
                 [4.93325116, 6.12697327, 4.14574096],
                 [0.53878964, 4.09050444, 5.36539036],
                 [2.16017393, 3.48835314, 5.55174441],
                 [2.52385107, 0.2908958, 3.90040975]]

# helpers


@contextmanager
def null_context():
    yield
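
A minimal usage sketch: null_context lets a call site toggle a real context manager off without branching the with-statement itself (torch.no_grad here is just an illustrative stand-in).

import torch

freeze = False
ctx = torch.no_grad() if freeze else null_context()
with ctx:
    y = torch.randn(3) * 2  # runs identically either way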
Example #11
import os
from math import pi
import torch
from torch import einsum
from einops import rearrange
from itertools import product
from contextlib import contextmanager

from se3_transformer_pytorch.irr_repr import irr_repr, spherical_harmonics
from se3_transformer_pytorch.utils import torch_default_dtype, cache_dir, exists, default, to_order
from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache

# constants

CACHE_PATH = default(os.getenv('CACHE_PATH'),
                     os.path.expanduser('~/.cache.equivariant_attention'))
CACHE_PATH = CACHE_PATH if not exists(os.environ.get('CLEAR_CACHE')) else None

# todo (figure out why this was hard coded in the official repo)

RANDOM_ANGLES = [[4.41301023, 5.56684102, 4.59384642],
                 [4.93325116, 6.12697327, 4.14574096],
                 [0.53878964, 4.09050444, 5.36539036],
                 [2.16017393, 3.48835314, 5.55174441],
                 [2.52385107, 0.2908958, 3.90040975]]

# helpers


@contextmanager
def null_context():
    yield
Example #12
    def forward(self, feats, coors, mask=None, edges=None, return_type=None):
        if exists(self.token_emb):
            feats = self.token_emb(feats)

        assert not (
            exists(edges) and not exists(self.edge_emb)
        ), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

        if exists(edges):
            edges = self.edge_emb(edges)

        if torch.is_tensor(feats):
            feats = {'0': feats[..., None]}

        b, n, d, *_, device = *feats['0'].shape, feats['0'].device

        assert d == self.dim, f'feature dimension {d} must be equal to dimension given at init {self.dim}'
        assert set(map(int, feats.keys())) == set(range(self.input_degrees)), \
            f'input must have {self.input_degrees} degrees'

        num_degrees, neighbors = self.num_degrees, self.num_neighbors
        neighbors = min(neighbors, n - 1)

        # exclude edge of token to itself

        exclude_self_mask = rearrange(
            ~torch.eye(n, dtype=torch.bool, device=device), 'i j -> () i j')
        indices = repeat(torch.arange(n, device=device),
                         'i -> b i j',
                         b=b,
                         j=n)
        rel_pos = rearrange(coors, 'b n d -> b n () d') - rearrange(
            coors, 'b n d -> b () n d')

        indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
        rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(
            b, n, n - 1, 3)

        if exists(mask):
            mask = rearrange(mask, 'b i -> b i ()') * rearrange(
                mask, 'b j -> b () j')
            mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)

        if exists(edges):
            edges = edges.masked_select(exclude_self_mask[..., None]).reshape(
                b, n, n - 1, -1)

        rel_dist = rel_pos.norm(dim=-1)

        # get neighbors and neighbor mask, excluding self

        neighbor_rel_dist, nearest_indices = rel_dist.topk(neighbors,
                                                           dim=-1,
                                                           largest=False)
        neighbor_rel_pos = batched_index_select(rel_pos,
                                                nearest_indices,
                                                dim=2)
        neighbor_indices = batched_index_select(indices,
                                                nearest_indices,
                                                dim=2)

        basis = get_basis(neighbor_rel_pos,
                          num_degrees - 1,
                          differentiable=self.differentiable_coors)

        neighbor_mask = neighbor_rel_dist <= self.valid_radius

        if exists(mask):
            neighbor_mask = neighbor_mask & batched_index_select(
                mask, nearest_indices, dim=2)

        if exists(edges):
            edges = batched_index_select(edges, nearest_indices, dim=2)

        # main logic

        edge_info = (neighbor_indices, neighbor_mask, edges)
        x = feats

        # project in

        x = self.conv_in(x, edge_info, rel_dist=neighbor_rel_dist, basis=basis)

        # transformer layers

        x = self.net(x,
                     edge_info=edge_info,
                     rel_dist=neighbor_rel_dist,
                     basis=basis)

        # project out

        x = self.conv_out(x,
                          edge_info,
                          rel_dist=neighbor_rel_dist,
                          basis=basis)

        # norm

        x = self.norm(x)

        # reduce dim if specified

        if exists(self.linear_out):
            x = self.linear_out(x)
            x = {k: v.squeeze(dim=2) for k, v in x.items()}

        if exists(return_type):
            return x[str(return_type)]

        return x
Example #13
    def __init__(self,
                 *,
                 dim,
                 num_neighbors=12,
                 heads=8,
                 dim_head=64,
                 depth=2,
                 num_degrees=2,
                 input_degrees=1,
                 output_degrees=2,
                 valid_radius=1e5,
                 reduce_dim_out=False,
                 num_tokens=None,
                 num_edge_tokens=None,
                 edge_dim=None,
                 reversible=False,
                 attend_self=False,
                 use_null_kv=False,
                 differentiable_coors=False):
        super().__init__()
        assert num_neighbors > 0, 'neighbors must be at least 1'
        self.dim = dim
        self.valid_radius = valid_radius

        self.token_emb = None
        self.token_emb = nn.Embedding(num_tokens,
                                      dim) if exists(num_tokens) else None

        assert not (
            exists(num_edge_tokens) and not exists(edge_dim)
        ), 'edge dimension (edge_dim) must be supplied if SE3 transformer is to have edge tokens'
        self.edge_emb = nn.Embedding(
            num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None

        self.input_degrees = input_degrees
        self.num_degrees = num_degrees
        self.num_neighbors = num_neighbors

        fiber_in = Fiber.create(input_degrees, dim)
        fiber_hidden = Fiber.create(num_degrees, dim)
        fiber_out = Fiber.create(output_degrees, dim)

        self.conv_in = ConvSE3(fiber_in, fiber_hidden, edge_dim=edge_dim)

        layers = nn.ModuleList([])
        for _ in range(depth):
            layers.append(
                nn.ModuleList([
                    AttentionBlockSE3(fiber_hidden,
                                      heads=heads,
                                      dim_head=dim_head,
                                      attend_self=attend_self,
                                      edge_dim=edge_dim,
                                      use_null_kv=use_null_kv),
                    FeedForwardBlockSE3(fiber_hidden)
                ]))

        execution_class = ReversibleSequence if reversible else SequentialSequence
        self.net = execution_class(layers)

        self.conv_out = ConvSE3(fiber_hidden, fiber_out, edge_dim=edge_dim)

        self.norm = NormSE3(fiber_out)

        self.linear_out = LinearSE3(fiber_out, Fiber.create(
            output_degrees, 1)) if reduce_dim_out else None

        self.differentiable_coors = differentiable_coors
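
A hedged usage sketch: with output_degrees=2 as defaulted above, the output dict carries both an invariant and an equivariant field, selectable via return_type.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(dim=32, depth=1, num_degrees=2, output_degrees=2)

feats = torch.randn(1, 16, 32)
coors = torch.randn(1, 16, 3)

out0 = model(feats, coors, return_type=0)  # type-0 (invariant) features
out1 = model(feats, coors, return_type=1)  # type-1 (vector) features, last dim 3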
Example #14
    def __init__(self,
                 fiber,
                 dim_head=64,
                 heads=8,
                 attend_self=False,
                 edge_dim=None,
                 fourier_encode_dist=False,
                 rel_dist_num_fourier_features=4,
                 use_null_kv=False,
                 splits=4,
                 global_feats_dim=None):
        super().__init__()
        hidden_dim = dim_head * heads
        hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber)))
        project_out = not (heads == 1 and len(fiber.dims) == 1
                           and dim_head == fiber.dims[0])

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = LinearSE3(fiber, hidden_fiber)
        self.to_k = ConvSE3(fiber,
                            hidden_fiber,
                            edge_dim=edge_dim,
                            pool=False,
                            self_interaction=False,
                            fourier_encode_dist=fourier_encode_dist,
                            num_fourier_features=rel_dist_num_fourier_features,
                            splits=splits)
        self.to_v = ConvSE3(fiber,
                            hidden_fiber,
                            edge_dim=edge_dim,
                            pool=False,
                            self_interaction=False,
                            fourier_encode_dist=fourier_encode_dist,
                            num_fourier_features=rel_dist_num_fourier_features,
                            splits=splits)
        self.to_out = LinearSE3(hidden_fiber,
                                fiber) if project_out else nn.Identity()

        self.use_null_kv = use_null_kv
        if use_null_kv:
            self.null_keys = nn.ParameterDict()
            self.null_values = nn.ParameterDict()

            for degree in fiber.degrees:
                m = to_order(degree)
                degree_key = str(degree)
                self.null_keys[degree_key] = nn.Parameter(
                    torch.zeros(heads, dim_head, m))
                self.null_values[degree_key] = nn.Parameter(
                    torch.zeros(heads, dim_head, m))

        self.attend_self = attend_self
        if attend_self:
            self.to_self_k = LinearSE3(fiber, hidden_fiber)
            self.to_self_v = LinearSE3(fiber, hidden_fiber)

        self.accept_global_feats = exists(global_feats_dim)
        if self.accept_global_feats:
            global_input_fiber = Fiber.create(1, global_feats_dim)
            global_output_fiber = Fiber.create(1, hidden_fiber[0])
            self.to_global_k = LinearSE3(global_input_fiber,
                                         global_output_fiber)
            self.to_global_v = LinearSE3(global_input_fiber,
                                         global_output_fiber)