def forward(self, inp, edge_info, rel_dist=None, basis=None):
    """Equivariant convolution over neighbor features.

    inp: dict mapping degree string ('0', '1', ...) to feature tensors.
    edge_info: (neighbor_indices, neighbor_masks, edges) tuple.
    rel_dist: pairwise distances — assumes shape (b, m, n); TODO confirm against caller.
    basis: precomputed equivariant basis passed through to each kernel.
    Returns a dict keyed by output degree string.
    """
    neighbor_indices, neighbor_masks, edges = edge_info
    rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')
    kernels = {}
    outputs = {}
    # precompute one kernel per (degree_in, degree_out) pair
    for (di, mi), (do, mo) in (self.fiber_in * self.fiber_out):
        etype = f'({di},{do})'
        kernel_fn = self.kernel_unary[etype]
        # concatenate edge embeddings onto the distance features when present
        edge_features = torch.cat(
            (rel_dist, edges), dim=-1) if exists(edges) else rel_dist
        kernels[etype] = kernel_fn(edge_features, basis=basis)
    # for each output degree, sum contributions from every input degree
    for degree_out in self.fiber_out.degrees:
        output = 0
        degree_out_key = str(degree_out)
        for degree_in, m_in in self.fiber_in:
            x = inp[str(degree_in)]
            # gather the features of each node's neighbors
            x = batched_index_select(x, neighbor_indices, dim=1)
            x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)
            etype = f'({degree_in},{degree_out})'
            kernel = kernels[etype]
            output = output + einsum('... o i, ... i c -> ... o c', kernel, x)
        if self.pool:
            # mean over the neighbor dimension, respecting the mask when given
            output = masked_mean(
                output, neighbor_masks,
                dim=2) if exists(neighbor_masks) else output.mean(dim=2)
        leading_shape = x.shape[:2] if self.pool else x.shape[:3]
        output = output.view(*leading_shape, -1, to_order(degree_out))
        outputs[degree_out_key] = output
    # optional residual-style self interaction on the raw input
    if self.self_interaction:
        self_interact_out = self.self_interact(inp)
        outputs = self.self_interact_sum(outputs, self_interact_out)
    return outputs
def forward(self, feats, coors, mask=None, adj_mat=None, edges=None, return_type=None, return_pooled=False):
    """Run the SE3 transformer.

    feats: tensor (b, n, d) or dict of degree-string -> tensor.
    coors: (b, n, 3) coordinates.
    mask: optional (b, n) boolean node mask.
    adj_mat: optional (i, j) or (b, i, j) adjacency matrix for sparse neighbors.
    edges: optional edge-type token tensor (requires edge_emb at init).
    return_type: degree to return directly (forced to 0 if output_degrees == 1).
    return_pooled: mean-pool over the sequence dimension before returning.
    """
    _mask = mask
    if self.output_degrees == 1:
        return_type = 0
    if exists(self.token_emb):
        feats = self.token_emb(feats)
    assert not (
        self.attend_sparse_neighbors and not exists(adj_mat)
    ), 'adjacency matrix (adjacency_mat) or edges (edges) must be passed in'
    assert not (
        exists(edges) and not exists(self.edge_emb)
    ), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'
    if torch.is_tensor(feats):
        feats = {'0': feats[..., None]}
    b, n, d, *_, device = *feats['0'].shape, feats['0'].device
    assert d == self.dim, f'feature dimension {d} must be equal to dimension given at init {self.dim}'
    assert set(map(int, feats.keys())) == set(
        range(self.input_degrees
              )), f'input must have {self.input_degrees} degree'
    num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius
    assert self.attend_sparse_neighbors or neighbors > 0, 'you must either attend to sparsely bonded neighbors, or set number of locally attended neighbors to be greater than 0'
    # create N-degrees adjacent matrix from 1st degree connections
    if exists(self.num_adj_degrees):
        if len(adj_mat.shape) == 2:
            adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b=b)
        adj_indices = adj_mat.clone().long()
        for ind in range(self.num_adj_degrees - 1):
            degree = ind + 2
            # nodes reachable in one more hop
            next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
            next_degree_mask = (next_degree_adj_mat.float() -
                                adj_mat.float()).bool()
            adj_indices.masked_fill_(next_degree_mask, degree)
            adj_mat = next_degree_adj_mat.clone()
    # se3 transformer by default cannot have a node attend to itself
    exclude_self_mask = rearrange(
        ~torch.eye(n, dtype=torch.bool, device=device), 'i j -> () i j')
    # calculate sparsely connected neighbors
    sparse_neighbor_mask = None
    num_sparse_neighbors = 0
    if self.attend_sparse_neighbors:
        assert exists(
            adj_mat
        ), 'adjacency matrix must be passed in (keyword argument adj_mat)'
        if exists(adj_mat):
            # FIX: was `len(adj_mat) == 2`, which checks the first dimension's
            # size rather than the tensor rank (cf. the check above)
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat, 'i j -> b i j', b=b)
            adj_mat = adj_mat.masked_select(exclude_self_mask).reshape(
                b, n, n - 1)
            adj_mat_values = adj_mat.float()
            adj_mat_max_neighbors = adj_mat_values.sum(dim=-1).max().item()
            # break ties randomly when limiting the number of sparse neighbors
            if max_sparse_neighbors < adj_mat_max_neighbors:
                noise = torch.empty_like(adj_mat_values).uniform_(-0.01, 0.01)
                adj_mat_values += noise
            num_sparse_neighbors = int(
                min(max_sparse_neighbors, adj_mat_max_neighbors))
            values, indices = adj_mat_values.topk(num_sparse_neighbors, dim=-1)
            sparse_neighbor_mask = torch.zeros_like(adj_mat_values).scatter_(
                -1, indices, values)
            sparse_neighbor_mask = sparse_neighbor_mask > 0.5
    # exclude edge of token to itself
    indices = repeat(torch.arange(n, device=device), 'i -> b i j', b=b, j=n)
    rel_pos = rearrange(coors, 'b n d -> b n () d') - rearrange(
        coors, 'b n d -> b () n d')
    indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
    rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(
        b, n, n - 1, 3)
    if exists(mask):
        mask = rearrange(mask, 'b i -> b i ()') * rearrange(
            mask, 'b j -> b () j')
        mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)
    if exists(edges):
        edges = self.edge_emb(edges)
        edges = edges.masked_select(exclude_self_mask[..., None]).reshape(
            b, n, n - 1, -1)
    if exists(self.adj_emb):
        adj_emb = self.adj_emb(adj_indices)
        edges = torch.cat(
            (edges, adj_emb), dim=-1) if exists(edges) else adj_emb
    rel_dist = rel_pos.norm(dim=-1)
    # use sparse neighbor mask to assign priority of bonded
    # FIX: clone before in-place masked_fill_ so the original rel_dist
    # (used below for neighbor_rel_dist) is not silently mutated
    modified_rel_dist = rel_dist.clone()
    if exists(sparse_neighbor_mask):
        modified_rel_dist.masked_fill_(sparse_neighbor_mask, 0.)
    # if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix
    if neighbors == 0:
        valid_radius = 0
    # get neighbors and neighbor mask, excluding self
    neighbors = int(min(neighbors, n - 1))
    total_neighbors = int(neighbors + num_sparse_neighbors)
    assert total_neighbors > 0, 'you must be fetching at least 1 neighbor'
    total_neighbors = int(
        min(total_neighbors, n - 1)
    )  # make sure total neighbors does not exceed the length of the sequence itself
    dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors,
                                                          dim=-1,
                                                          largest=False)
    neighbor_mask = dist_values <= valid_radius
    neighbor_rel_dist = batched_index_select(rel_dist, nearest_indices, dim=2)
    neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim=2)
    neighbor_indices = batched_index_select(indices, nearest_indices, dim=2)
    if exists(mask):
        neighbor_mask = neighbor_mask & batched_index_select(
            mask, nearest_indices, dim=2)
    if exists(edges):
        edges = batched_index_select(edges, nearest_indices, dim=2)
    # calculate basis
    basis = get_basis(neighbor_rel_pos,
                      num_degrees - 1,
                      differentiable=self.differentiable_coors)
    # main logic
    edge_info = (neighbor_indices, neighbor_mask, edges)
    x = feats
    # project in
    x = self.conv_in(x, edge_info, rel_dist=neighbor_rel_dist, basis=basis)
    # transformer layers
    x = self.net(x, edge_info=edge_info, rel_dist=neighbor_rel_dist, basis=basis)
    # project out
    x = self.conv_out(x, edge_info, rel_dist=neighbor_rel_dist, basis=basis)
    # norm
    x = self.norm(x)
    # reduce dim if specified
    if exists(self.linear_out):
        x = self.linear_out(x)
        x = map_values(lambda t: t.squeeze(dim=2), x)
    if return_pooled:
        mask_fn = (lambda t: masked_mean(t, _mask, dim=1)
                   ) if exists(_mask) else (lambda t: t.mean(dim=1))
        x = map_values(mask_fn, x)
    if '0' in x:
        x['0'] = x['0'].squeeze(dim=-1)
    if exists(return_type):
        return x[str(return_type)]
    return x
def __init__(self, *, dim, heads=8, dim_head=64, depth=2, input_degrees=1, num_degrees=2, output_degrees=1, valid_radius=1e5, reduce_dim_out=False, num_tokens=None, num_edge_tokens=None, edge_dim=None, reversible=False, attend_self=True, use_null_kv=False, differentiable_coors=False, fourier_encode_dist=False, num_neighbors=float('inf'), attend_sparse_neighbors=False, num_adj_degrees=None, adj_dim=0, max_sparse_neighbors=float('inf')):
    """Build the SE3 transformer: token/edge embeddings, input/output
    equivariant convolutions, and `depth` attention + feedforward layers.
    All keyword-only; interface unchanged."""
    super().__init__()
    self.dim = dim
    # FIX: removed dead `self.token_emb = None` that was immediately overwritten
    self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
    assert not (
        exists(num_edge_tokens) and not exists(edge_dim)
    ), 'edge dimension (edge_dim) must be supplied if SE3 transformer is to have edge tokens'
    self.edge_emb = nn.Embedding(
        num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
    self.input_degrees = input_degrees
    self.num_degrees = num_degrees
    self.output_degrees = output_degrees
    # whether to differentiate through basis, needed for alphafold2
    self.differentiable_coors = differentiable_coors
    # neighbors hyperparameters
    self.valid_radius = valid_radius
    self.num_neighbors = num_neighbors
    # sparse neighbors, derived from adjacency matrix or edges being passed in
    self.attend_sparse_neighbors = attend_sparse_neighbors
    self.max_sparse_neighbors = max_sparse_neighbors
    # adjacent neighbor derivation and embed
    assert not (exists(num_adj_degrees) and num_adj_degrees < 1
                ), 'make sure adjacent degrees is greater than 1'
    self.num_adj_degrees = num_adj_degrees
    self.adj_emb = nn.Embedding(
        num_adj_degrees + 1,
        adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None
    # total edge feature dimension fed to the convolutions
    edge_dim = (edge_dim if exists(self.edge_emb) else 0) + (adj_dim if exists(self.adj_emb) else 0)
    # main network
    fiber_in = Fiber.create(input_degrees, dim)
    fiber_hidden = Fiber.create(num_degrees, dim)
    fiber_out = Fiber.create(output_degrees, dim)
    self.conv_in = ConvSE3(fiber_in,
                           fiber_hidden,
                           edge_dim=edge_dim,
                           fourier_encode_dist=fourier_encode_dist)
    layers = nn.ModuleList([])
    for _ in range(depth):
        layers.append(
            nn.ModuleList([
                AttentionBlockSE3(fiber_hidden,
                                  heads=heads,
                                  dim_head=dim_head,
                                  attend_self=attend_self,
                                  edge_dim=edge_dim,
                                  fourier_encode_dist=fourier_encode_dist,
                                  use_null_kv=use_null_kv),
                FeedForwardBlockSE3(fiber_hidden)
            ]))
    execution_class = ReversibleSequence if reversible else SequentialSequence
    self.net = execution_class(layers)
    self.conv_out = ConvSE3(fiber_hidden,
                            fiber_out,
                            edge_dim=edge_dim,
                            fourier_encode_dist=fourier_encode_dist)
    self.norm = NormSE3(fiber_out)
    self.linear_out = LinearSE3(fiber_out, Fiber.create(
        output_degrees, 1)) if reduce_dim_out else None
def forward(self, features, edge_info, rel_dist, basis):
    """Multi-headed equivariant attention over each node's neighbors.

    features: dict of degree-string -> tensor.
    edge_info: (neighbor_indices, neighbor_mask, edges).
    Returns a dict with the same degree keys, projected through to_out.
    """
    h, attend_self = self.heads, self.attend_self
    device, dtype = get_tensor_device_and_dtype(features)  # NOTE(review): device/dtype unused in this version
    neighbor_indices, neighbor_mask, edges = edge_info
    max_neg_value = -torch.finfo().max
    if exists(neighbor_mask):
        # broadcast the mask over the head dimension
        neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')
    # NOTE(review): this rearranged tensor is never read afterwards
    neighbor_indices = rearrange(neighbor_indices, 'b i j -> b () i j')
    queries = self.to_q(features)
    # keys and values are convolutions over neighbors (basis-dependent)
    keys, values = self.to_k(features, edge_info, rel_dist, basis), self.to_v(features, edge_info, rel_dist, basis)
    if attend_self:
        self_keys, self_values = self.to_self_k(features), self.to_self_v(
            features)
    outputs = {}
    for degree in features.keys():
        q, k, v = map(lambda t: t[degree], (queries, keys, values))
        q = rearrange(q, 'b i (h d) m -> b h i d m', h=h)
        k, v = map(
            lambda t: rearrange(t, 'b i j (h d) m -> b h i j d m', h=h),
            (k, v))
        # prepend learned null key/value (an attention sink)
        if self.use_null_kv:
            null_k, null_v = map(lambda t: t[degree],
                                 (self.null_keys, self.null_values))
            null_k, null_v = map(
                lambda t: repeat(
                    t, 'h d m -> b h i () d m', b=q.shape[0], i=q.shape[
                        2]), (null_k, null_v))
            k = torch.cat((null_k, k), dim=3)
            v = torch.cat((null_v, v), dim=3)
        # prepend each node's own key/value so it can attend to itself
        if attend_self:
            self_k, self_v = map(lambda t: t[degree], (self_keys,
                                                       self_values))
            self_k, self_v = map(
                lambda t: rearrange(t, 'b n (h d) m -> b h n () d m', h=h),
                (self_k, self_v))
            k = torch.cat((self_k, k), dim=3)
            v = torch.cat((self_v, v), dim=3)
        sim = einsum('b h i d m, b h i j d m -> b h i j', q, k) * self.scale
        if exists(neighbor_mask):
            # pad the mask on the left to cover the prepended self / null slots
            num_left_pad = int(attend_self) + int(self.use_null_kv)
            mask = F.pad(neighbor_mask, (num_left_pad, 0), value=True)
            sim.masked_fill_(~mask, max_neg_value)
        attn = sim.softmax(dim=-1)
        out = einsum('b h i j, b h i j d m -> b h i d m', attn, v)
        outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')
    return self.to_out(outputs)
def forward(
    self,
    feats,
    coors,
    mask = None,
    adj_mat = None,
    edges = None,
    return_type = None,
    return_pooled = False,
    neighbor_mask = None,
    global_feats = None
):
    """Run the SE3 transformer (version with global features, causal masking,
    pre-convolutions and rotary embeddings).

    feats: tensor (b, n, d) or dict of degree-string -> tensor.
    coors: (b, n, 3) coordinates.
    mask: optional (b, n) boolean node mask.
    adj_mat: optional (i, j) or (b, i, j) adjacency matrix.
    edges: optional edge tokens/features.
    neighbor_mask: optional (b, n, n) mask restricting who may be a neighbor.
    global_feats: optional global features (only if accept_global_feats).
    """
    assert not (self.accept_global_feats ^ exists(global_feats)), 'you cannot pass in global features unless you init the class correctly'
    _mask = mask
    if self.output_degrees == 1:
        return_type = 0
    if exists(self.token_emb):
        feats = self.token_emb(feats)
    assert not (self.attend_sparse_neighbors and not exists(adj_mat)), 'adjacency matrix (adjacency_mat) or edges (edges) must be passed in'
    assert not (self.has_edges and not exists(edges)), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'
    if torch.is_tensor(feats):
        feats = {'0': feats[..., None]}
    if torch.is_tensor(global_feats):
        global_feats = {'0': global_feats[..., None]}
    b, n, d, *_, device = *feats['0'].shape, feats['0'].device
    assert d == self.dim_in[0], f'feature dimension {d} must be equal to dimension given at init {self.dim_in[0]}'
    assert set(map(int, feats.keys())) == set(range(self.input_degrees)), f'input must have {self.input_degrees} degree'
    num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius
    assert self.attend_sparse_neighbors or neighbors > 0, 'you must either attend to sparsely bonded neighbors, or set number of locally attended neighbors to be greater than 0'
    # se3 transformer by default cannot have a node attend to itself
    exclude_self_mask = rearrange(~torch.eye(n, dtype = torch.bool, device = device), 'i j -> () i j')
    # drops the diagonal, turning (b, n, n) pairwise tensors into (b, n, n-1)
    remove_self = lambda t: t.masked_select(exclude_self_mask).reshape(b, n, n - 1)
    get_max_value = lambda t: torch.finfo(t.dtype).max
    # create N-degrees adjacent matrix from 1st degree connections
    if exists(self.num_adj_degrees):
        if len(adj_mat.shape) == 2:
            adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)
        adj_indices = adj_mat.clone().long()
        for ind in range(self.num_adj_degrees - 1):
            degree = ind + 2
            # nodes reachable in one more hop
            next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
            next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
            adj_indices.masked_fill_(next_degree_mask, degree)
            adj_mat = next_degree_adj_mat.clone()
        adj_indices = adj_indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
    # calculate sparsely connected neighbors
    sparse_neighbor_mask = None
    num_sparse_neighbors = 0
    if self.attend_sparse_neighbors:
        assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'
        if exists(adj_mat):
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat, 'i j -> b i j', b = b)
            adj_mat = remove_self(adj_mat)
            adj_mat_values = adj_mat.float()
            adj_mat_max_neighbors = adj_mat_values.sum(dim = -1).max().item()
            # random noise breaks ties when limiting the number of sparse neighbors
            if max_sparse_neighbors < adj_mat_max_neighbors:
                noise = torch.empty_like(adj_mat_values).uniform_(-0.01, 0.01)
                adj_mat_values += noise
            num_sparse_neighbors = int(min(max_sparse_neighbors, adj_mat_max_neighbors))
            values, indices = adj_mat_values.topk(num_sparse_neighbors, dim = -1)
            sparse_neighbor_mask = torch.zeros_like(adj_mat_values).scatter_(-1, indices, values)
            sparse_neighbor_mask = sparse_neighbor_mask > 0.5
    # exclude edge of token to itself
    indices = repeat(torch.arange(n, device = device), 'i -> b i j', b = b, j = n)
    rel_pos = rearrange(coors, 'b n d -> b n () d') - rearrange(coors, 'b n d -> b () n d')
    indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
    rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, 3)
    if exists(mask):
        mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
        mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)
    if exists(edges):
        if exists(self.edge_emb):
            edges = self.edge_emb(edges)
        edges = edges.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, -1)
    if exists(self.adj_emb):
        adj_emb = self.adj_emb(adj_indices)
        edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb
    rel_dist = rel_pos.norm(dim = -1)
    # rel_dist gets modified using adjacency or neighbor mask
    modified_rel_dist = rel_dist.clone()
    max_value = get_max_value(modified_rel_dist) # for masking out nodes from being considered as neighbors
    # neighbors
    if exists(neighbor_mask):
        neighbor_mask = remove_self(neighbor_mask)
        max_neighbors = neighbor_mask.sum(dim = -1).max().item()
        if max_neighbors > neighbors:
            # NOTE(review): warning via print — consider warnings/logging
            print(f'neighbor_mask shows maximum number of neighbors as {max_neighbors} but specified number of neighbors is {neighbors}')
        modified_rel_dist.masked_fill_(~neighbor_mask, max_value)
    # use sparse neighbor mask to assign priority of bonded
    if exists(sparse_neighbor_mask):
        modified_rel_dist.masked_fill_(sparse_neighbor_mask, 0.)
    # mask out future nodes to high distance if causal turned on
    if self.causal:
        causal_mask = torch.ones(n, n - 1, device = device).triu().bool()
        modified_rel_dist.masked_fill_(causal_mask[None, ...], max_value)
    # if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix
    if neighbors == 0:
        valid_radius = 0
    # get neighbors and neighbor mask, excluding self
    neighbors = int(min(neighbors, n - 1))
    total_neighbors = int(neighbors + num_sparse_neighbors)
    assert total_neighbors > 0, 'you must be fetching at least 1 neighbor'
    total_neighbors = int(min(total_neighbors, n - 1)) # make sure total neighbors does not exceed the length of the sequence itself
    dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors, dim = -1, largest = False)
    neighbor_mask = dist_values <= valid_radius
    neighbor_rel_dist = batched_index_select(rel_dist, nearest_indices, dim = 2)
    neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim = 2)
    neighbor_indices = batched_index_select(indices, nearest_indices, dim = 2)
    if exists(mask):
        neighbor_mask = neighbor_mask & batched_index_select(mask, nearest_indices, dim = 2)
    if exists(edges):
        edges = batched_index_select(edges, nearest_indices, dim = 2)
    # calculate rotary pos emb
    rotary_pos_emb = None
    rotary_query_pos_emb = None
    rotary_key_pos_emb = None
    if self.rotary_position:
        seq = torch.arange(n, device = device)
        seq_pos_emb = self.rotary_pos_emb(seq)
        self_indices = torch.arange(neighbor_indices.shape[1], device = device)
        self_indices = repeat(self_indices, 'i -> b i ()', b = b)
        # each key set is (self, neighbors), so prepend the node's own index
        neighbor_indices_with_self = torch.cat((self_indices, neighbor_indices), dim = 2)
        pos_emb = batched_index_select(seq_pos_emb, neighbor_indices_with_self, dim = 0)
        rotary_key_pos_emb = pos_emb
        rotary_query_pos_emb = repeat(seq_pos_emb, 'n d -> b n d', b = b)
    if self.rotary_rel_dist:
        # presumably the 1e2 scale spreads distances over the sinusoid periods — TODO confirm
        neighbor_rel_dist_with_self = F.pad(neighbor_rel_dist, (1, 0), value = 0) * 1e2
        rel_dist_pos_emb = self.rotary_pos_emb(neighbor_rel_dist_with_self)
        rotary_key_pos_emb = safe_cat(rotary_key_pos_emb, rel_dist_pos_emb, dim = -1)
        query_dist = torch.zeros(n, device = device)
        query_pos_emb = self.rotary_pos_emb(query_dist)
        query_pos_emb = repeat(query_pos_emb, 'n d -> b n d', b = b)
        rotary_query_pos_emb = safe_cat(rotary_query_pos_emb, query_pos_emb, dim = -1)
    if exists(rotary_query_pos_emb) and exists(rotary_key_pos_emb):
        rotary_pos_emb = (rotary_query_pos_emb, rotary_key_pos_emb)
    # calculate basis
    basis = get_basis(neighbor_rel_pos, num_degrees - 1, differentiable = self.differentiable_coors)
    # main logic
    edge_info = (neighbor_indices, neighbor_mask, edges)
    x = feats
    # project in
    x = self.conv_in(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)
    # preconvolution layers
    for conv, nonlin in self.convs:
        x = nonlin(x)
        x = conv(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)
    # transformer layers
    x = self.net(x, edge_info = edge_info, rel_dist = neighbor_rel_dist, basis = basis, global_feats = global_feats, pos_emb = rotary_pos_emb)
    # project out
    x = self.conv_out(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)
    # norm
    x = self.norm(x)
    # reduce dim if specified
    if exists(self.linear_out):
        x = self.linear_out(x)
        x = map_values(lambda t: t.squeeze(dim = 2), x)
    if return_pooled:
        mask_fn = (lambda t: masked_mean(t, _mask, dim = 1)) if exists(_mask) else (lambda t: t.mean(dim = 1))
        x = map_values(mask_fn, x)
    if '0' in x:
        x['0'] = x['0'].squeeze(dim = -1)
    if exists(return_type):
        return x[str(return_type)]
    return x
def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None):
    """Equivariant attention (version with linear key projection, tied k/v,
    rotary embeddings applied to degree-0 features, and global features).

    features: dict of degree-string -> tensor.
    edge_info: (neighbor_indices, neighbor_mask, edges).
    pos_emb: optional (query_pos_emb, key_pos_emb) rotary embedding pair.
    """
    h, attend_self = self.heads, self.attend_self
    device, dtype = get_tensor_device_and_dtype(features)  # NOTE(review): device/dtype unused here
    neighbor_indices, neighbor_mask, edges = edge_info
    max_neg_value = -torch.finfo().max
    if exists(neighbor_mask):
        neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')
    queries = self.to_q(features)
    values = self.to_v(features, edge_info, rel_dist, basis)
    if self.linear_proj_keys:
        # keys are a linear projection of node features, gathered per neighbor
        keys = self.to_k(features)
        keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys)
    elif not exists(self.to_k):
        # tied key/values
        keys = values
    else:
        keys = self.to_k(features, edge_info, rel_dist, basis)
    if attend_self:
        self_keys, self_values = self.to_self_k(features), self.to_self_v(features)
    if exists(global_feats):
        global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats)
    outputs = {}
    for degree in features.keys():
        q, k, v = map(lambda t: t[degree], (queries, keys, values))
        q = rearrange(q, 'b i (h d) m -> b h i d m', h = h)
        # prepend each node's own key/value so it can attend to itself
        if attend_self:
            self_k, self_v = map(lambda t: t[degree], (self_keys, self_values))
            self_k, self_v = map(lambda t: rearrange(t, 'b n d m -> b n () d m'), (self_k, self_v))
            k = torch.cat((self_k, k), dim = 2)
            v = torch.cat((self_v, v), dim = 2)
        # rotary embeddings only make sense for invariant (degree 0) features
        if exists(pos_emb) and degree == '0':
            query_pos_emb, key_pos_emb = pos_emb
            query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
            key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b i j d ()')
            q = apply_rotary_pos_emb(q, query_pos_emb)
            k = apply_rotary_pos_emb(k, key_pos_emb)
            v = apply_rotary_pos_emb(v, key_pos_emb)
        # prepend learned null key/value (an attention sink)
        if self.use_null_kv:
            null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values))
            null_k, null_v = map(lambda t: repeat(t, 'd m -> b i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v))
            k = torch.cat((null_k, k), dim = 2)
            v = torch.cat((null_v, v), dim = 2)
        # prepend global tokens, degree 0 only
        if exists(global_feats) and degree == '0':
            global_k, global_v = map(lambda t: t[degree], (global_keys, global_values))
            global_k, global_v = map(lambda t: repeat(t, 'b j d m -> b i j d m', i = k.shape[1]), (global_k, global_v))
            k = torch.cat((global_k, k), dim = 2)
            v = torch.cat((global_v, v), dim = 2)
        # note: keys/values are single-headed here (no h axis)
        sim = einsum('b h i d m, b i j d m -> b h i j', q, k) * self.scale
        if exists(neighbor_mask):
            # pad the mask on the left to cover all prepended slots
            num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
            mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
            sim.masked_fill_(~mask, max_neg_value)
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b i j d m -> b h i d m', attn, v)
        outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')
    return self.to_out(outputs)
def __init__(
    self,
    *,
    dim,
    heads = 8,
    dim_head = 24,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 1,
    valid_radius = 1e5,
    reduce_dim_out = False,
    num_tokens = None,
    num_edge_tokens = None,
    edge_dim = None,
    reversible = False,
    attend_self = True,
    use_null_kv = False,
    differentiable_coors = False,
    fourier_encode_dist = False,
    rel_dist_num_fourier_features = 4,
    num_neighbors = float('inf'),
    attend_sparse_neighbors = False,
    num_adj_degrees = None,
    adj_dim = 0,
    max_sparse_neighbors = float('inf'),
    dim_in = None,
    dim_out = None,
    norm_out = False,
    num_conv_layers = 0,
    causal = False,
    splits = 4,
    global_feats_dim = None,
    linear_proj_keys = False,
    one_headed_key_values = False,
    tie_key_values = False,
    rotary_position = False,
    rotary_rel_dist = False
):
    """Build the SE3 transformer (version with rotary embeddings,
    pre-convolutions, global features, and causal option).
    All keyword-only; interface unchanged."""
    super().__init__()
    dim_in = default(dim_in, dim)
    self.dim_in = cast_tuple(dim_in, input_degrees)
    self.dim = dim
    # token embedding
    self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
    # positional embedding
    self.rotary_rel_dist = rotary_rel_dist
    self.rotary_position = rotary_position
    self.rotary_pos_emb = None
    if rotary_position or rotary_rel_dist:
        # split the head dimension between the enabled rotary schemes
        num_rotaries = int(rotary_position) + int(rotary_rel_dist)
        self.rotary_pos_emb = SinusoidalEmbeddings(dim_head // num_rotaries)
    # edges
    assert not (exists(num_edge_tokens) and not exists(edge_dim)), 'edge dimension (edge_dim) must be supplied if SE3 transformer is to have edge tokens'
    self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
    self.has_edges = exists(edge_dim) and edge_dim > 0
    self.input_degrees = input_degrees
    self.num_degrees = num_degrees
    self.output_degrees = output_degrees
    # whether to differentiate through basis, needed for alphafold2
    self.differentiable_coors = differentiable_coors
    # neighbors hyperparameters
    self.valid_radius = valid_radius
    self.num_neighbors = num_neighbors
    # sparse neighbors, derived from adjacency matrix or edges being passed in
    self.attend_sparse_neighbors = attend_sparse_neighbors
    self.max_sparse_neighbors = max_sparse_neighbors
    # adjacent neighbor derivation and embed
    assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure adjacent degrees is greater than 1'
    self.num_adj_degrees = num_adj_degrees
    self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None
    # total edge feature dimension fed to the convolutions
    edge_dim = (edge_dim if self.has_edges else 0) + (adj_dim if exists(self.adj_emb) else 0)
    # define fibers and dimensionality
    # FIX: removed redundant second `dim_in = default(dim_in, dim)` — dim_in
    # was already defaulted at the top of this method
    dim_out = default(dim_out, dim)
    fiber_in = Fiber.create(input_degrees, dim_in)
    fiber_hidden = Fiber.create(num_degrees, dim)
    fiber_out = Fiber.create(output_degrees, dim_out)
    conv_kwargs = dict(edge_dim = edge_dim, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
    # causal
    assert not (causal and not attend_self), 'attending to self must be turned on if in autoregressive mode (for the first token)'
    self.causal = causal
    # main network
    self.conv_in = ConvSE3(fiber_in, fiber_hidden, **conv_kwargs)
    # pre-convs
    self.convs = nn.ModuleList([])
    for _ in range(num_conv_layers):
        self.convs.append(nn.ModuleList([
            ConvSE3(fiber_hidden, fiber_hidden, **conv_kwargs),
            NormSE3(fiber_hidden)
        ]))
    # global features
    self.accept_global_feats = exists(global_feats_dim)
    assert not (reversible and self.accept_global_feats), 'reversibility and global features are not compatible'
    # trunk
    self.attend_self = attend_self
    attention_klass = OneHeadedKVAttentionSE3 if one_headed_key_values else AttentionSE3
    layers = nn.ModuleList([])
    for _ in range(depth):
        layers.append(nn.ModuleList([
            AttentionBlockSE3(fiber_hidden, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, fourier_encode_dist = fourier_encode_dist, rel_dist_num_fourier_features = rel_dist_num_fourier_features, use_null_kv = use_null_kv, splits = splits, global_feats_dim = global_feats_dim, linear_proj_keys = linear_proj_keys, attention_klass = attention_klass, tie_key_values = tie_key_values),
            FeedForwardBlockSE3(fiber_hidden)
        ]))
    execution_class = ReversibleSequence if reversible else SequentialSequence
    self.net = execution_class(layers)
    # out
    self.conv_out = ConvSE3(fiber_hidden, fiber_out, **conv_kwargs)
    self.norm = NormSE3(fiber_out, nonlin = nn.Identity()) if norm_out or reversible else nn.Identity()
    self.linear_out = LinearSE3(
        fiber_out,
        Fiber.create(output_degrees, 1)
    ) if reduce_dim_out else None
def forward(
    self,
    inp,
    edge_info,
    rel_dist = None,
    basis = None
):
    """Equivariant convolution, processing the sequence in `self.splits`
    chunks to bound peak memory.

    inp: dict of degree-string -> tensor.
    edge_info: (neighbor_indices, neighbor_masks, edges).
    basis: dict of precomputed basis tensors, chunked along dim 1 below.
    Returns a dict keyed by output degree string.
    """
    splits = self.splits
    neighbor_indices, neighbor_masks, edges = edge_info
    rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')
    # FIX: removed dead `kernels = {}` — kernels are computed per chunk
    # inside the loop in this chunked version and never cached
    outputs = {}
    if self.fourier_encode_dist:
        rel_dist = fourier_encode(rel_dist[..., None], num_encodings = self.num_fourier_features)
    # split basis
    basis_keys = basis.keys()
    split_basis_values = list(zip(*list(map(lambda t: fast_split(t, splits, dim = 1), basis.values()))))
    split_basis = list(map(lambda v: dict(zip(basis_keys, v)), split_basis_values))
    # go through every permutation of input degree type to output degree type
    for degree_out in self.fiber_out.degrees:
        output = 0
        degree_out_key = str(degree_out)
        for degree_in, m_in in self.fiber_in:
            etype = f'({degree_in},{degree_out})'
            x = inp[str(degree_in)]
            # gather the features of each node's neighbors
            x = batched_index_select(x, neighbor_indices, dim = 1)
            x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)
            kernel_fn = self.kernel_unary[etype]
            edge_features = torch.cat((rel_dist, edges), dim = -1) if exists(edges) else rel_dist
            output_chunk = None
            split_x = fast_split(x, splits, dim = 1)
            split_edge_features = fast_split(edge_features, splits, dim = 1)
        # process input, edges, and basis in chunks along the sequence dimension
            # FIX: chunk loop variables renamed — they previously shadowed the
            # `basis` parameter and the `edge_features` local
            for x_chunk, edge_features_chunk, basis_chunk in zip(split_x, split_edge_features, split_basis):
                kernel = kernel_fn(edge_features_chunk, basis = basis_chunk)
                chunk = einsum('... o i, ... i c -> ... o c', kernel, x_chunk)
                output_chunk = safe_cat(output_chunk, chunk, dim = 1)
            output = output + output_chunk
        if self.pool:
            # mean over the neighbor dimension, respecting the mask when given
            output = masked_mean(output, neighbor_masks, dim = 2) if exists(neighbor_masks) else output.mean(dim = 2)
        leading_shape = x.shape[:2] if self.pool else x.shape[:3]
        output = output.view(*leading_shape, -1, to_order(degree_out))
        outputs[degree_out_key] = output
    # optional residual-style self interaction on the raw input
    if self.self_interaction:
        self_interact_out = self.self_interact(inp)
        outputs = self.self_interact_sum(outputs, self_interact_out)
    return outputs
def __init__(
    self,
    fiber,
    dim_head = 64,
    heads = 8,
    attend_self = False,
    edge_dim = None,
    fourier_encode_dist = False,
    rel_dist_num_fourier_features = 4,
    use_null_kv = False,
    splits = 4,
    global_feats_dim = None,
    linear_proj_keys = False,
    tie_key_values = False
):
    """Equivariant attention module.

    Queries are a linear projection; values (and optionally keys) are
    equivariant convolutions over neighbors. Keys/values are single-headed
    (dim_head wide), queries multi-headed (dim_head * heads wide).
    """
    super().__init__()
    hidden_dim = dim_head * heads
    hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber)))
    kv_hidden_fiber = Fiber(list(map(lambda t: (t[0], dim_head), fiber)))
    # skip the output projection when it would be an exact identity shape
    project_out = not (heads == 1 and len(fiber.dims) == 1 and dim_head == fiber.dims[0])
    self.scale = dim_head ** -0.5
    self.heads = heads
    self.linear_proj_keys = linear_proj_keys # whether to linearly project features for keys, rather than convolve with basis
    self.to_q = LinearSE3(fiber, hidden_fiber)
    self.to_v = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
    assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time'
    if linear_proj_keys:
        self.to_k = LinearSE3(fiber, kv_hidden_fiber)
    elif not tie_key_values:
        self.to_k = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
    else:
        # tied key/values: forward() falls back to using values as keys
        self.to_k = None
    self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity()
    self.use_null_kv = use_null_kv
    if use_null_kv:
        # learned null key/value per degree, acting as an attention sink
        self.null_keys = nn.ParameterDict()
        self.null_values = nn.ParameterDict()
        for degree in fiber.degrees:
            m = to_order(degree)
            degree_key = str(degree)
            self.null_keys[degree_key] = nn.Parameter(torch.zeros(dim_head, m))
            self.null_values[degree_key] = nn.Parameter(torch.zeros(dim_head, m))
    self.attend_self = attend_self
    if attend_self:
        self.to_self_k = LinearSE3(fiber, kv_hidden_fiber)
        self.to_self_v = LinearSE3(fiber, kv_hidden_fiber)
    self.accept_global_feats = exists(global_feats_dim)
    if self.accept_global_feats:
        # global features only exist at degree 0
        global_input_fiber = Fiber.create(1, global_feats_dim)
        global_output_fiber = Fiber.create(1, kv_hidden_fiber[0])
        self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber)
        self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber)
import os
from math import pi
import torch
from torch import einsum
from einops import rearrange
from itertools import product
from contextlib import contextmanager

from se3_transformer_pytorch.irr_repr import irr_repr, spherical_harmonics
from se3_transformer_pytorch.utils import torch_default_dtype, cache_dir, exists, to_order
from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache

# constants

# on-disk cache for precomputed bases; disabled entirely when CLEAR_CACHE is set
# NOTE(review): path uses a dot ('~/.cache.equivariant_attention') rather than
# a subdirectory — presumably intentional, matches the original repo
CACHE_PATH = os.path.expanduser(
    '~/.cache.equivariant_attention') if not exists(
        os.environ.get('CLEAR_CACHE')) else None

# todo (figure out why this was hard coded in official repo)
RANDOM_ANGLES = [[4.41301023, 5.56684102, 4.59384642],
                 [4.93325116, 6.12697327, 4.14574096],
                 [0.53878964, 4.09050444, 5.36539036],
                 [2.16017393, 3.48835314, 5.55174441],
                 [2.52385107, 0.2908958, 3.90040975]]

# helpers

@contextmanager
def null_context():
    # no-op context manager, used where caching context is conditionally applied
    yield
# FIX: added missing `import os` — this header uses os.getenv / os.path /
# os.environ below but never imported os
import os
from math import pi
import torch
from torch import einsum
from einops import rearrange
from itertools import product
from contextlib import contextmanager

from se3_transformer_pytorch.irr_repr import irr_repr, spherical_harmonics
from se3_transformer_pytorch.utils import torch_default_dtype, cache_dir, exists, default, to_order
from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache

# constants

# on-disk cache for precomputed bases, overridable via CACHE_PATH env var;
# disabled entirely when CLEAR_CACHE is set
CACHE_PATH = default(os.getenv('CACHE_PATH'), os.path.expanduser('~/.cache.equivariant_attention'))
CACHE_PATH = CACHE_PATH if not exists(os.environ.get('CLEAR_CACHE')) else None

# todo (figure out why this was hard coded in official repo)
RANDOM_ANGLES = [
    [4.41301023, 5.56684102, 4.59384642],
    [4.93325116, 6.12697327, 4.14574096],
    [0.53878964, 4.09050444, 5.36539036],
    [2.16017393, 3.48835314, 5.55174441],
    [2.52385107, 0.2908958, 3.90040975]
]

# helpers

@contextmanager
def null_context():
    # no-op context manager, used where caching context is conditionally applied
    yield
def forward(self, feats, coors, mask=None, edges=None, return_type=None):
    """Run the SE3 transformer on a point cloud.

    feats       : tensor of node features (or token ids if a token embedding
                  was configured), or a dict mapping degree-string -> tensor.
                  A plain tensor is promoted to degree-0 features via a
                  trailing singleton dim.
    coors       : coordinates per node — indexed as (b, n, 3) below.
    mask        : optional boolean node mask, shape (b, n).
    edges       : optional edge tokens; requires edge_emb to be configured.
    return_type : if given, return only the output tensor for that degree.
    """
    # token-id input -> embedded features
    if exists(self.token_emb):
        feats = self.token_emb(feats)

    assert not (
        exists(edges) and not exists(self.edge_emb)
    ), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

    if exists(edges):
        edges = self.edge_emb(edges)

    # normalize plain tensor input to the dict-of-degrees representation
    if torch.is_tensor(feats):
        feats = {'0': feats[..., None]}

    b, n, d, *_, device = *feats['0'].shape, feats['0'].device

    assert d == self.dim, f'feature dimension {d} must be equal to dimension given at init {self.dim}'
    assert set(map(int, feats.keys())) == set(
        range(self.input_degrees
              )), f'input must have {self.input_degrees} degree'

    num_degrees, neighbors = self.num_degrees, self.num_neighbors
    # a node has at most n - 1 neighbors once self is excluded
    neighbors = min(neighbors, n - 1)

    # exclude edge of token to itself
    exclude_self_mask = rearrange(
        ~torch.eye(n, dtype=torch.bool, device=device), 'i j -> () i j')
    indices = repeat(torch.arange(n, device=device), 'i -> b i j', b=b, j=n)
    # pairwise relative positions, (b, n, n, 3)
    rel_pos = rearrange(coors, 'b n d -> b n () d') - rearrange(
        coors, 'b n d -> b () n d')

    # drop the diagonal (self-pairs) from indices / rel_pos / mask / edges
    indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
    rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(
        b, n, n - 1, 3)

    if exists(mask):
        # pairwise mask: both endpoints must be valid
        mask = rearrange(mask, 'b i -> b i ()') * rearrange(
            mask, 'b j -> b () j')
        mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)

    if exists(edges):
        edges = edges.masked_select(exclude_self_mask[..., None]).reshape(
            b, n, n - 1, -1)

    rel_dist = rel_pos.norm(dim=-1)

    # get neighbors and neighbor mask, excluding self
    # topk with largest=False selects the `neighbors` nearest nodes
    neighbor_rel_dist, nearest_indices = rel_dist.topk(neighbors, dim=-1, largest=False)
    neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim=2)
    neighbor_indices = batched_index_select(indices, nearest_indices, dim=2)
    # equivariant basis for all degrees up to num_degrees - 1
    basis = get_basis(neighbor_rel_pos, num_degrees - 1, differentiable=self.differentiable_coors)
    # neighbors beyond valid_radius are masked out of attention/convolution
    neighbor_mask = neighbor_rel_dist <= self.valid_radius

    if exists(mask):
        neighbor_mask = neighbor_mask & batched_index_select(
            mask, nearest_indices, dim=2)

    if exists(edges):
        edges = batched_index_select(edges, nearest_indices, dim=2)

    # main logic

    edge_info = (neighbor_indices, neighbor_mask, edges)
    x = feats

    # project in
    x = self.conv_in(x, edge_info, rel_dist=neighbor_rel_dist, basis=basis)

    # transformer layers
    x = self.net(x, edge_info=edge_info, rel_dist=neighbor_rel_dist, basis=basis)

    # project out
    x = self.conv_out(x, edge_info, rel_dist=neighbor_rel_dist, basis=basis)

    # norm
    x = self.norm(x)

    # reduce dim if specified
    if exists(self.linear_out):
        x = self.linear_out(x)
        # after reducing to a single channel, squeeze it away
        x = {k: v.squeeze(dim=2) for k, v in x.items()}

    if exists(return_type):
        return x[str(return_type)]

    return x
def __init__(self,
             *,
             dim,
             num_neighbors=12,
             heads=8,
             dim_head=64,
             depth=2,
             num_degrees=2,
             input_degrees=1,
             output_degrees=2,
             valid_radius=1e5,
             reduce_dim_out=False,
             num_tokens=None,
             num_edge_tokens=None,
             edge_dim=None,
             reversible=False,
             attend_self=False,
             use_null_kv=False,
             differentiable_coors=False):
    """SE3 transformer: conv-in -> attention/feedforward stack -> conv-out.

    dim                  : feature dimension per degree.
    num_neighbors        : number of nearest neighbors each node attends to.
    heads, dim_head      : attention heads and per-head dimension.
    depth                : number of attention + feedforward blocks.
    num_degrees          : number of hidden representation degrees.
    input_degrees        : degrees expected in the input feature dict.
    output_degrees       : degrees produced by the output convolution.
    valid_radius         : neighbors beyond this distance are masked out.
    reduce_dim_out       : if True, project each output degree down to 1 channel.
    num_tokens           : vocabulary size for an optional token embedding.
    num_edge_tokens      : vocabulary size for an optional edge embedding;
                           requires edge_dim.
    edge_dim             : embedded edge feature dimension.
    reversible           : use a reversible layer sequence to save memory.
    attend_self          : let each node attend to itself.
    use_null_kv          : learned null key/value per degree.
    differentiable_coors : make the basis computation differentiable w.r.t.
                           coordinates.
    """
    super().__init__()
    assert num_neighbors > 0, 'neighbors must be at least 1'
    self.dim = dim
    self.valid_radius = valid_radius

    # fix: removed a dead `self.token_emb = None` assignment that was
    # immediately overwritten by the conditional below
    self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None

    assert not (
        exists(num_edge_tokens) and not exists(edge_dim)
    ), 'edge dimension (edge_dim) must be supplied if SE3 transformer is to have edge tokens'
    self.edge_emb = nn.Embedding(
        num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None

    self.input_degrees = input_degrees
    self.num_degrees = num_degrees
    self.num_neighbors = num_neighbors

    # fibers describe (degree, multiplicity) pairs for each stage
    fiber_in = Fiber.create(input_degrees, dim)
    fiber_hidden = Fiber.create(num_degrees, dim)
    fiber_out = Fiber.create(output_degrees, dim)

    self.conv_in = ConvSE3(fiber_in, fiber_hidden, edge_dim=edge_dim)

    layers = nn.ModuleList([])
    for _ in range(depth):
        layers.append(
            nn.ModuleList([
                AttentionBlockSE3(fiber_hidden,
                                  heads=heads,
                                  dim_head=dim_head,
                                  attend_self=attend_self,
                                  edge_dim=edge_dim,
                                  use_null_kv=use_null_kv),
                FeedForwardBlockSE3(fiber_hidden)
            ]))

    execution_class = ReversibleSequence if reversible else SequentialSequence
    self.net = execution_class(layers)

    self.conv_out = ConvSE3(fiber_hidden, fiber_out, edge_dim=edge_dim)
    self.norm = NormSE3(fiber_out)

    # optional per-degree projection down to a single channel
    self.linear_out = LinearSE3(fiber_out, Fiber.create(
        output_degrees, 1)) if reduce_dim_out else None

    self.differentiable_coors = differentiable_coors
def __init__(self, fiber, dim_head=64, heads=8, attend_self=False, edge_dim=None, fourier_encode_dist=False, rel_dist_num_fourier_features=4, use_null_kv=False, splits=4, global_feats_dim=None):
    """Equivariant attention: linear queries, convolutional keys/values,
    with optional self, null, and global key/value branches.
    """
    super().__init__()
    inner_dim = dim_head * heads
    # widen every degree of the input fiber to the inner attention dimension
    inner_fiber = Fiber([(degree, inner_dim) for degree, _ in fiber])
    # skip the output projection only in the degenerate single-head,
    # single-degree case where dimensions already match
    needs_projection = heads != 1 or len(fiber.dims) != 1 or dim_head != fiber.dims[0]

    self.scale = dim_head**-0.5
    self.heads = heads

    self.to_q = LinearSE3(fiber, inner_fiber)
    self.to_k = ConvSE3(fiber,
                        inner_fiber,
                        edge_dim=edge_dim,
                        pool=False,
                        self_interaction=False,
                        fourier_encode_dist=fourier_encode_dist,
                        num_fourier_features=rel_dist_num_fourier_features,
                        splits=splits)
    self.to_v = ConvSE3(fiber,
                        inner_fiber,
                        edge_dim=edge_dim,
                        pool=False,
                        self_interaction=False,
                        fourier_encode_dist=fourier_encode_dist,
                        num_fourier_features=rel_dist_num_fourier_features,
                        splits=splits)

    if needs_projection:
        self.to_out = LinearSE3(inner_fiber, fiber)
    else:
        self.to_out = nn.Identity()

    self.use_null_kv = use_null_kv
    if use_null_kv:
        # one learned (zero-initialized) null key/value per degree,
        # shaped (heads, dim_head, order)
        self.null_keys = nn.ParameterDict()
        self.null_values = nn.ParameterDict()
        for degree in fiber.degrees:
            order = to_order(degree)
            key = str(degree)
            self.null_keys[key] = nn.Parameter(
                torch.zeros(heads, dim_head, order))
            self.null_values[key] = nn.Parameter(
                torch.zeros(heads, dim_head, order))

    self.attend_self = attend_self
    if attend_self:
        # separate key/value projections for each node attending to itself
        self.to_self_k = LinearSE3(fiber, inner_fiber)
        self.to_self_v = LinearSE3(fiber, inner_fiber)

    self.accept_global_feats = exists(global_feats_dim)
    if self.accept_global_feats:
        # global features are degree-0 only
        global_input_fiber = Fiber.create(1, global_feats_dim)
        global_output_fiber = Fiber.create(1, inner_fiber[0])
        self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber)
        self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber)