def forward(
        self,
        value: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
        key: Union[Tensor, Dict[str, Tensor]],    # edge features (may be fused)
        query: Dict[str, Tensor],                 # node features
        graph: DGLGraph
):
    with nvtx_range('AttentionSE3'):
        with nvtx_range('reshape keys and queries'):
            if isinstance(key, Tensor):
                # case where features of all types are fused
                key = key.reshape(key.shape[0], self.num_heads, -1)
                # need to reshape queries that way to keep the same layout as keys
                out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
                query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
            else:
                # features are not fused, need to fuse and reshape them
                key = self.key_fiber.to_attention_heads(key, self.num_heads)
                query = self.key_fiber.to_attention_heads(query, self.num_heads)

        with nvtx_range('attention dot product + softmax'):
            # Compute attention weights (softmax of inner product between key and query)
            edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
            edge_weights = edge_weights / np.sqrt(self.key_fiber.num_features)
            edge_weights = edge_softmax(graph, edge_weights)
            edge_weights = edge_weights[..., None, None]

        with nvtx_range('weighted sum'):
            if isinstance(value, Tensor):
                # features of all types are fused
                v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
                weights = edge_weights * v
                feat_out = dgl.ops.copy_e_sum(graph, weights)
                feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1])  # merge heads
                out = unfuse_features(feat_out, self.value_fiber.degrees)
            else:
                out = {}
                for degree, channels in self.value_fiber:
                    v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
                                                degree_to_dim(degree))
                    weights = edge_weights * v
                    res = dgl.ops.copy_e_sum(graph, weights)
                    out[str(degree)] = res.view(-1, channels, degree_to_dim(degree))  # merge heads

        return out
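# A minimal standalone sketch (not the module above) of the two DGL primitives
# this attention relies on, under assumed toy shapes: dgl.ops.e_dot_v computes,
# for every edge (u -> v), the dot product of the edge tensor with the
# destination-node tensor, and edge_softmax then normalizes the resulting
# scores over each node's incoming edges.
import dgl
import torch
from dgl.nn.functional import edge_softmax

g = dgl.graph(([0, 1, 2], [2, 2, 0]))    # 3 nodes, 3 edges
key = torch.randn(3, 4)                  # one feature vector per edge
query = torch.randn(3, 4)                # one feature vector per node
scores = dgl.ops.e_dot_v(g, key, query)  # [num_edges, 1]: per-edge attention logits
alpha = edge_softmax(g, scores)          # sums to 1 over the incoming edges of each node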
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
    with nvtx_range('NormSE3'):
        output = {}
        if hasattr(self, 'group_norm'):
            # Compute per-degree norms of features
            norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
                     for d in self.fiber.degrees]
            fused_norms = torch.cat(norms, dim=-2)

            # Transform the norms only
            new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
            new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)

            # Scale features to the new norms
            for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
                output[str(d)] = features[str(d)] / norm * new_norm
        else:
            for degree, feat in features.items():
                norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
                new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
                output[degree] = new_norm * feat / norm

        return output
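# Illustrative sketch (not part of the library) of the norm-based nonlinearity
# NormSE3 applies, shown for a single degree-1 tensor of shape
# [num_nodes, channels, 2*1+1]. The names below (feat, NORM_CLAMP) and the
# inline LayerNorm/ReLU choices are assumptions for illustration.
import torch

feat = torch.randn(16, 8, 3)                                  # degree-1 features
NORM_CLAMP = 2 ** -24                                         # assumed small clamp to avoid div-by-zero
norm = feat.norm(dim=-1, keepdim=True).clamp(min=NORM_CLAMP)  # rotation-invariant magnitudes
new_norm = torch.relu(torch.nn.LayerNorm(8)(norm.squeeze(-1))).unsqueeze(-1)
out = feat / norm * new_norm                                  # direction kept, magnitude transformed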
def forward(
        self,
        node_features: Dict[str, Tensor],
        edge_features: Dict[str, Tensor],
        graph: DGLGraph,
        basis: Dict[str, Tensor]
):
    with nvtx_range('AttentionBlockSE3'):
        with nvtx_range('keys / values'):
            fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
            key, value = self._get_key_value_from_fused(fused_key_value)

        with nvtx_range('queries'):
            query = self.to_query(node_features)

        z = self.attention(value, key, query, graph)
        z_concat = aggregate_residual(node_features, z, 'cat')
        return self.project(z_concat)
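# Hypothetical stand-in for aggregate_residual(node_features, z, 'cat') as used
# above, assuming 'cat' concatenates matching degrees along the channel
# dimension and passes through degrees that exist only in the second dict.
from typing import Dict
import torch
from torch import Tensor

def aggregate_residual_cat(feats1: Dict[str, Tensor], feats2: Dict[str, Tensor]) -> Dict[str, Tensor]:
    return {k: torch.cat([v, feats1[k]], dim=1) if k in feats1 else v
            for k, v in feats2.items()}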
def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
    with nvtx_range('VersatileConvSE3'):
        num_edges = features.shape[0]
        in_dim = features.shape[2]
        with nvtx_range('RadialProfile'):
            radial_weights = self.radial_func(invariant_edge_feats) \
                .view(-1, self.channels_out, self.channels_in * self.freq_sum)

        if basis is not None:
            # This block performs the einsum n i l, n o i f, n l f k -> n o k
            basis_view = basis.view(num_edges, in_dim, -1)
            tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
            return radial_weights @ tmp
        else:
            # k = l = 0 non-fused case
            return radial_weights @ features
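# Sketch verifying that the two batched matmuls above match the reference
# einsum 'nil,noif,nlfk->nok' from the comment. All shapes and names are
# illustrative toy values, not the library's API.
import torch

n, i, o, l, f, k = 10, 4, 5, 3, 2, 3                  # edges, channels in/out, dims, freqs
features = torch.randn(n, i, l)
radial_weights = torch.randn(n, o, i * f)
basis = torch.randn(n, l, f, k)

ref = torch.einsum('nil,noif,nlfk->nok', features, radial_weights.view(n, o, i, f), basis)
tmp = (features @ basis.view(n, l, -1)).view(n, -1, k)  # [n, i*f, k]
fused = radial_weights @ tmp                            # [n, o, k]
assert torch.allclose(ref, fused, atol=1e-5)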
def get_basis(relative_pos: Tensor,
              max_degree: int = 4,
              compute_gradients: bool = False,
              use_pad_trick: bool = False,
              amp: bool = False) -> Dict[str, Tensor]:
    with nvtx_range('spherical harmonics'):
        spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree)
    with nvtx_range('CB coefficients'):
        clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device)

    with torch.autograd.set_grad_enabled(compute_gradients):
        with nvtx_range('bases'):
            basis = get_basis_script(max_degree=max_degree,
                                     use_pad_trick=use_pad_trick,
                                     spherical_harmonics=spherical_harmonics,
                                     clebsch_gordon=clebsch_gordon,
                                     amp=amp)
        return basis
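# Hypothetical usage sketch for get_basis: precompute the per-edge bases once
# per graph and reuse them across layers. The '<d_in>,<d_out>' key convention
# follows the one used in the ConvSE3 fallback path below.
import torch

relative_pos = torch.randn(128, 3)  # one 3D displacement vector per edge
basis = get_basis(relative_pos, max_degree=3, compute_gradients=False)
# e.g. basis['1,1'] would hold the degree-1 -> degree-1 basis for all 128 edges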
def forward(
        self,
        node_feats: Dict[str, Tensor],
        edge_feats: Dict[str, Tensor],
        graph: DGLGraph,
        basis: Dict[str, Tensor]
):
    with nvtx_range('ConvSE3'):
        invariant_edge_feats = edge_feats['0'].squeeze(-1)
        src, dst = graph.edges()
        out = {}
        in_features = []

        # Fetch all input features from edge and node features
        for degree_in in self.fiber_in.degrees:
            src_node_features = node_feats[str(degree_in)][src]
            if degree_in > 0 and str(degree_in) in edge_feats:
                # Handle edge features of any type by concatenating them to node features
                src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
            in_features.append(src_node_features)

        if self.used_fuse_level == ConvSE3FuseLevel.FULL:
            in_features_fused = torch.cat(in_features, dim=-1)
            out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])

            if not self.allow_fused_output or self.self_interaction or self.pool:
                out = unfuse_features(out, self.fiber_out.degrees)

        elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
            in_features_fused = torch.cat(in_features, dim=-1)
            for degree_out in self.fiber_out.degrees:
                basis_used = basis[f'out{degree_out}_fused']
                out[str(degree_out)] = self._try_unpad(
                    self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis_used),
                    basis_used)

        elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
            out = 0
            for degree_in, feature in zip(self.fiber_in.degrees, in_features):
                out = out + self.conv_in[str(degree_in)](feature, invariant_edge_feats,
                                                         basis[f'in{degree_in}_fused'])
            if not self.allow_fused_output or self.self_interaction or self.pool:
                out = unfuse_features(out, self.fiber_out.degrees)

        else:
            # Fallback to pairwise TFN convolutions
            for degree_out in self.fiber_out.degrees:
                out_feature = 0
                for degree_in, feature in zip(self.fiber_in.degrees, in_features):
                    dict_key = f'{degree_in},{degree_out}'
                    basis_used = basis.get(dict_key, None)
                    out_feature = out_feature + self._try_unpad(
                        self.conv[dict_key](feature, invariant_edge_feats, basis_used),
                        basis_used)
                out[str(degree_out)] = out_feature

        for degree_out in self.fiber_out.degrees:
            if self.self_interaction and str(degree_out) in self.to_kernel_self:
                with nvtx_range('self interaction'):
                    dst_features = node_feats[str(degree_out)][dst]
                    kernel_self = self.to_kernel_self[str(degree_out)]
                    out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features

            if self.pool:
                with nvtx_range('pooling'):
                    if isinstance(out, dict):
                        out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
                    else:
                        out = dgl.ops.copy_e_sum(graph, out)

        return out
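# Minimal sketch (assumed toy shapes) of dgl.ops.copy_e_sum, the primitive
# behind the pooling step above: it sums each edge tensor into its
# destination node.
import dgl
import torch

g = dgl.graph(([0, 1], [2, 2]))              # two edges, both pointing at node 2
edge_feat = torch.ones(2, 5, 3)
node_out = dgl.ops.copy_e_sum(g, edge_feat)  # [3, 5, 3]; row 2 holds the sum of both edges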
def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
    all_degrees = list(range(2 * max_degree + 1))
    with nvtx_range('spherical harmonics'):
        sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
        return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
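# Hypothetical usage check: with max_degree=2 the function evaluates degrees
# 0..4 and splits the concatenated harmonics into chunks of width 2d + 1
# (1 + 3 + 5 + 7 + 9 = 25 columns in total).
import torch

rel_pos = torch.randn(7, 3)
chunks = get_spherical_harmonics(rel_pos, max_degree=2)
assert [c.shape[-1] for c in chunks] == [1, 3, 5, 7, 9]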