Ejemplo n.º 1
0
    def forward(
            self,
            value: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
            key: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
            query: Dict[str, Tensor],  # node features
            graph: DGLGraph):
        """Multi-head attention where keys/values live on edges and queries on nodes.

        Keys and queries are laid out as ``(rows, num_heads, -1)``. Per-edge scores
        come from ``dgl.ops.e_dot_v`` and are scaled by
        ``1 / sqrt(self.key_fiber.num_features)``. They are then normalised with
        ``edge_softmax``. Values are weighted by those scores and summed onto nodes
        with ``dgl.ops.copy_e_sum``, after which the heads are merged back together.

        Args:
            value: edge values, either one fused tensor or a dict keyed by ``str(degree)``.
            key: edge keys, either one fused tensor or a dict keyed by ``str(degree)``.
            query: node queries keyed by ``str(degree)``.
            graph: the DGL graph the attention runs over.

        Returns:
            Per-degree node outputs. In the non-fused path this is a dict keyed by
            ``str(degree)``. In the fused path it is whatever ``unfuse_features``
            returns (presumably the same dict layout; TODO confirm).
        """
        with nvtx_range('AttentionSE3'):
            with nvtx_range('reshape keys and queries'):
                if not isinstance(key, Tensor):
                    # Per-degree dicts: the fiber fuses them and splits into heads.
                    key = self.key_fiber.to_attention_heads(key, self.num_heads)
                    query = self.key_fiber.to_attention_heads(query, self.num_heads)
                else:
                    # Keys arrive already fused, so only split them into heads.
                    key = key.reshape(key.shape[0], self.num_heads, -1)
                    # Queries must be concatenated in the same degree order to match.
                    fused_query = torch.cat(
                        [query[str(deg)] for deg in self.key_fiber.degrees], dim=-1)
                    num_nodes = next(iter(query.values())).shape[0]
                    query = fused_query.reshape(num_nodes, self.num_heads, -1)

            with nvtx_range('attention dot product + softmax'):
                scores = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
                scores = scores / np.sqrt(self.key_fiber.num_features)
                # Two trailing singleton axes let the weights broadcast over values.
                attn = edge_softmax(graph, scores)[..., None, None]

            with nvtx_range('weighted sum'):
                if isinstance(value, Tensor):
                    # All degrees are fused into a single value tensor.
                    per_head = value.view(value.shape[0], self.num_heads, -1,
                                          value.shape[-1])
                    summed = dgl.ops.copy_e_sum(graph, attn * per_head)
                    merged = summed.view(summed.shape[0], -1, summed.shape[-1])
                    return unfuse_features(merged, self.value_fiber.degrees)

                result = {}
                for degree, channels in self.value_fiber:
                    dim = degree_to_dim(degree)
                    per_head = value[str(degree)].view(
                        -1, self.num_heads, channels // self.num_heads, dim)
                    summed = dgl.ops.copy_e_sum(graph, attn * per_head)
                    result[str(degree)] = summed.view(-1, channels, dim)  # merge heads
                return result
Ejemplo n.º 2
0
    def forward(self, features: Dict[str, Tensor], *args,
                **kwargs) -> Dict[str, Tensor]:
        """Replace each feature vector's norm with a transformed norm, keeping its direction.

        Norms are taken over the last dimension and clamped below at
        ``self.NORM_CLAMP``. Each output is ``feature / norm * new_norm``, or the
        algebraically equal ``new_norm * feature / norm`` in the per-degree path.

        When the module has a ``group_norm`` attribute, the norms of every degree
        in ``self.fiber.degrees`` are concatenated on dim -2. They are then passed
        through ``group_norm`` followed by ``nonlinearity`` and split back into
        equal chunks. That split assumes every degree has the same number of
        channels. Without ``group_norm``, each feature key uses its own
        ``self.layer_norms[key]`` followed by ``nonlinearity``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        with nvtx_range('NormSE3'):
            if not hasattr(self, 'group_norm'):
                result = {}
                for deg_key, tensor in features.items():
                    magnitude = tensor.norm(dim=-1, keepdim=True).clamp(
                        min=self.NORM_CLAMP)
                    transformed = self.layer_norms[deg_key](magnitude.squeeze(-1))
                    rescaled = self.nonlinearity(transformed.unsqueeze(-1))
                    result[deg_key] = rescaled * tensor / magnitude
                return result

            degree_keys = [str(deg) for deg in self.fiber.degrees]
            magnitudes = [
                features[k].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
                for k in degree_keys
            ]
            # Normalise all degrees' norms jointly, then hand each degree its slice.
            stacked = torch.cat(magnitudes, dim=-2).squeeze(-1)
            transformed = self.nonlinearity(self.group_norm(stacked)).unsqueeze(-1)
            pieces = torch.chunk(transformed, chunks=len(degree_keys), dim=-2)
            return {
                k: features[k] / mag * piece
                for k, mag, piece in zip(degree_keys, magnitudes, pieces)
            }
Ejemplo n.º 3
0
    def forward(self, node_features: Dict[str, Tensor],
                edge_features: Dict[str, Tensor], graph: DGLGraph,
                basis: Dict[str, Tensor]):
        """Run one attention block: key/value conv, queries, attention, residual, projection.

        Keys and values come from a single fused ``to_key_value`` pass that is
        split by ``_get_key_value_from_fused``. Queries come from ``to_query`` on
        the node features. The attention output is concatenated with the input
        node features by ``aggregate_residual(..., 'cat')``, then passed through
        ``project``.
        """
        with nvtx_range('AttentionBlockSE3'):
            with nvtx_range('keys / values'):
                key_value = self.to_key_value(node_features, edge_features,
                                              graph, basis)
                key, value = self._get_key_value_from_fused(key_value)

            with nvtx_range('queries'):
                query = self.to_query(node_features)

            attended = self.attention(value, key, query, graph)
            with_residual = aggregate_residual(node_features, attended, 'cat')
            return self.project(with_residual)
Ejemplo n.º 4
0
    def forward(self, features: Tensor, invariant_edge_feats: Tensor,
                basis: Tensor):
        """Contract edge features with radial weights and, optionally, an equivariant basis.

        ``self.radial_func(invariant_edge_feats)`` is reshaped to
        ``(-1, channels_out, channels_in * freq_sum)``. With a basis, this computes
        the einsum ``n i l, n o i f, n l f k -> n o k`` as two batched matmuls.
        When ``basis`` is None (the k = l = 0 non-fused case), the radial weights
        multiply ``features`` directly.
        """
        with nvtx_range('VersatileConvSE3'):
            edge_count, feat_dim = features.shape[0], features.shape[2]
            with nvtx_range('RadialProfile'):
                flat_weights = self.radial_func(invariant_edge_feats)
                radial_weights = flat_weights.view(
                    -1, self.channels_out, self.channels_in * self.freq_sum)

            if basis is None:
                # k = l = 0 non-fused case: nothing to project onto
                return radial_weights @ features

            # First contract over l: (n, i, l) @ (n, l, f*k) -> (n, i, f*k).
            # Then regroup to (n, i*f, k) so the radial weights contract over (i, f).
            basis_flat = basis.view(edge_count, feat_dim, -1)
            projected = (features @ basis_flat).view(edge_count, -1,
                                                     basis.shape[-1])
            return radial_weights @ projected
Ejemplo n.º 5
0
def get_basis(relative_pos: Tensor,
              max_degree: int = 4,
              compute_gradients: bool = False,
              use_pad_trick: bool = False,
              amp: bool = False) -> Dict[str, Tensor]:
    """Build the equivariant bases for the given relative positions.

    Spherical harmonics of ``relative_pos`` and the Clebsch-Gordan coefficients
    for ``max_degree`` (on ``relative_pos.device``) are computed under the
    caller's current grad mode. They are then combined by ``get_basis_script``
    with autograd enabled only if ``compute_gradients`` is True.

    Args:
        relative_pos: relative positions the basis is evaluated at.
        max_degree: maximum feature degree.
        compute_gradients: whether gradients are tracked while combining the bases.
        use_pad_trick: forwarded to ``get_basis_script``.
        amp: forwarded to ``get_basis_script``.

    Returns:
        The dict of basis tensors produced by ``get_basis_script``.
    """
    with nvtx_range('spherical harmonics'):
        harmonics = get_spherical_harmonics(relative_pos, max_degree)
    with nvtx_range('CB coefficients'):
        cg_coeffs = get_all_clebsch_gordon(max_degree, relative_pos.device)

    with torch.autograd.set_grad_enabled(compute_gradients), nvtx_range('bases'):
        return get_basis_script(max_degree=max_degree,
                                use_pad_trick=use_pad_trick,
                                spherical_harmonics=harmonics,
                                clebsch_gordon=cg_coeffs,
                                amp=amp)
Ejemplo n.º 6
0
    def forward(self, node_feats: Dict[str, Tensor], edge_feats: Dict[str,
                                                                      Tensor],
                graph: DGLGraph, basis: Dict[str, Tensor]):
        """Apply the SE(3) convolution over the edges of ``graph``.

        Source-node features are gathered per edge. Edge features of the same
        non-zero degree are concatenated onto them along dim 1. The convolution
        then runs with the strategy chosen by ``self.used_fuse_level`` and by
        which sub-modules exist (``conv_out`` / ``conv_in`` / ``conv``).
        Optionally, a self-interaction term computed from destination-node
        features is added. Also optionally, edge outputs are pooled onto nodes
        with ``dgl.ops.copy_e_sum``.

        Args:
            node_feats: per-degree node features keyed by ``str(degree)``.
            edge_feats: per-degree edge features keyed by ``str(degree)``.
                ``edge_feats['0']`` must be present; with its last dim squeezed,
                it is passed as the invariant edge input to every conv module.
            graph: the DGL graph whose edges are convolved over.
            basis: precomputed bases. The keys read depend on the path taken:
                'fully_fused', 'out{d}_fused', 'in{d}_fused', or '{d_in},{d_out}'.

        Returns:
            A dict keyed by ``str(degree_out)``, or a single fused tensor. The
            fused tensor is returned only when the FULL or PARTIAL ``conv_in``
            path runs, ``allow_fused_output`` is set, and neither
            ``self_interaction`` nor ``pool`` is enabled.
        """
        with nvtx_range(f'ConvSE3'):
            invariant_edge_feats = edge_feats['0'].squeeze(-1)
            src, dst = graph.edges()
            out = {}
            in_features = []

            # Fetch all input features from edge and node features
            for degree_in in self.fiber_in.degrees:
                src_node_features = node_feats[str(degree_in)][src]
                if degree_in > 0 and str(degree_in) in edge_feats:
                    # Handle edge features of any type by concatenating them to node features
                    # (only non-zero degrees; degree-0 edge features are used above as the
                    # invariant edge input instead)
                    src_node_features = torch.cat(
                        [src_node_features, edge_feats[str(degree_in)]], dim=1)
                in_features.append(src_node_features)

            if self.used_fuse_level == ConvSE3FuseLevel.FULL:
                # One conv over all input degrees concatenated on the last dim.
                in_features_fused = torch.cat(in_features, dim=-1)
                out = self.conv(in_features_fused, invariant_edge_feats,
                                basis['fully_fused'])

                # Split back per degree whenever later steps need a dict.
                if not self.allow_fused_output or self.self_interaction or self.pool:
                    out = unfuse_features(out, self.fiber_out.degrees)

            elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(
                    self, 'conv_out'):
                # One conv per output degree, each reading all input degrees.
                in_features_fused = torch.cat(in_features, dim=-1)
                for degree_out in self.fiber_out.degrees:
                    basis_used = basis[f'out{degree_out}_fused']
                    # _try_unpad presumably strips padding added by the pad trick -- verify
                    out[str(degree_out)] = self._try_unpad(
                        self.conv_out[str(degree_out)](in_features_fused,
                                                       invariant_edge_feats,
                                                       basis_used), basis_used)

            elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(
                    self, 'conv_in'):
                # One conv per input degree; their outputs are summed.
                out = 0
                for degree_in, feature in zip(self.fiber_in.degrees,
                                              in_features):
                    out = out + self.conv_in[str(degree_in)](
                        feature, invariant_edge_feats,
                        basis[f'in{degree_in}_fused'])
                if not self.allow_fused_output or self.self_interaction or self.pool:
                    out = unfuse_features(out, self.fiber_out.degrees)
            else:
                # Fallback to pairwise TFN convolutions
                # For each output degree, sum one conv per (degree_in, degree_out) pair.
                # basis.get(...) yields None for pairs without a basis entry.
                for degree_out in self.fiber_out.degrees:
                    out_feature = 0
                    for degree_in, feature in zip(self.fiber_in.degrees,
                                                  in_features):
                        dict_key = f'{degree_in},{degree_out}'
                        basis_used = basis.get(dict_key, None)
                        out_feature = out_feature + self._try_unpad(
                            self.conv[dict_key](feature, invariant_edge_feats,
                                                basis_used), basis_used)
                    out[str(degree_out)] = out_feature

            for degree_out in self.fiber_out.degrees:
                if self.self_interaction and str(
                        degree_out) in self.to_kernel_self:
                    with nvtx_range(f'self interaction'):
                        # Add a kernel applied to the destination-node features of each edge.
                        dst_features = node_feats[str(degree_out)][dst]
                        kernel_self = self.to_kernel_self[str(degree_out)]
                        out[str(degree_out)] = out[str(
                            degree_out)] + kernel_self @ dst_features

                if self.pool:
                    with nvtx_range(f'pooling'):
                        # Sum edge outputs onto nodes.
                        if isinstance(out, dict):
                            out[str(degree_out)] = dgl.ops.copy_e_sum(
                                graph, out[str(degree_out)])
                        else:
                            # NOTE(review): self.pool forces unfusing above, so `out` should
                            # always be a dict here. This branch looks unreachable. If it
                            # were reached with several output degrees, it would pool the
                            # same tensor once per degree.
                            out = dgl.ops.copy_e_sum(graph, out)
            return out
Ejemplo n.º 7
0
def get_spherical_harmonics(relative_pos: Tensor,
                            max_degree: int) -> List[Tensor]:
    """Evaluate normalized spherical harmonics of degrees 0 .. 2 * max_degree.

    All degrees are evaluated in a single ``o3.spherical_harmonics`` call
    (``normalize=True``). The result is then split along dim 1 into one block
    per degree.

    Args:
        relative_pos: positions at which the harmonics are evaluated
            (presumably shape (num_edges, 3); TODO confirm).
        max_degree: harmonics are computed up to degree ``2 * max_degree``.

    Returns:
        A list with one tensor per degree ``d`` in increasing order. Each tensor
        holds ``degree_to_dim(d)`` columns.
    """
    degrees = list(range(2 * max_degree + 1))
    with nvtx_range('spherical harmonics'):
        sh = o3.spherical_harmonics(degrees, relative_pos, normalize=True)
        # torch.split returns a tuple; convert it so the value actually matches
        # the declared List[Tensor] return type.
        return list(torch.split(sh, [degree_to_dim(d) for d in degrees], dim=1))