Example #1
    def forward(
            self,
            value: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
            key: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
            query: Dict[str, Tensor],  # node features
            graph: DGLGraph):
        with nvtx_range('AttentionSE3'):
            with nvtx_range('reshape keys and queries'):
                if isinstance(key, Tensor):
                    # case where features of all types are fused
                    key = key.reshape(key.shape[0], self.num_heads, -1)
                    # need to reshape queries that way to keep the same layout as keys
                    out = torch.cat(
                        [query[str(d)] for d in self.key_fiber.degrees],
                        dim=-1)
                    query = out.reshape(
                        list(query.values())[0].shape[0], self.num_heads, -1)
                else:
                    # features are not fused, need to fuse and reshape them
                    key = self.key_fiber.to_attention_heads(
                        key, self.num_heads)
                    query = self.key_fiber.to_attention_heads(
                        query, self.num_heads)

            with nvtx_range('attention dot product + softmax'):
                # Compute attention weights (softmax of inner product between key and query)
                edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
                edge_weights = edge_weights / np.sqrt(
                    self.key_fiber.num_features)
                edge_weights = edge_softmax(graph, edge_weights)
                edge_weights = edge_weights[..., None, None]

            with nvtx_range('weighted sum'):
                if isinstance(value, Tensor):
                    # features of all types are fused
                    v = value.view(value.shape[0], self.num_heads, -1,
                                   value.shape[-1])
                    weights = edge_weights * v
                    feat_out = dgl.ops.copy_e_sum(graph, weights)
                    feat_out = feat_out.view(feat_out.shape[0], -1,
                                             feat_out.shape[-1])  # merge heads
                    out = unfuse_features(feat_out, self.value_fiber.degrees)
                else:
                    out = {}
                    for degree, channels in self.value_fiber:
                        v = value[str(degree)].view(-1, self.num_heads,
                                                    channels // self.num_heads,
                                                    degree_to_dim(degree))
                        weights = edge_weights * v
                        res = dgl.ops.copy_e_sum(graph, weights)
                        out[str(degree)] = res.view(
                            -1, channels, degree_to_dim(degree))  # merge heads

                return out
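
For orientation, the method above follows a standard DGL attention pattern: a per-edge dot product between edge keys and destination-node queries (dgl.ops.e_dot_v), a softmax over each node's incoming edges (edge_softmax), and a weighted aggregation of edge values into nodes (dgl.ops.copy_e_sum). Below is a minimal sketch of that pattern, not taken from the repository; the toy graph, the shapes, and the scaling by sqrt(head_dim) are illustrative assumptions.

import dgl
import torch
from dgl.nn.functional import edge_softmax

g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # toy graph: 3 nodes, 3 edges
num_heads, head_dim = 2, 4
key = torch.randn(g.num_edges(), num_heads, head_dim)    # per-edge keys
value = torch.randn(g.num_edges(), num_heads, head_dim)  # per-edge values
query = torch.randn(g.num_nodes(), num_heads, head_dim)  # per-node queries

# dot product between each edge's key and its destination node's query
logits = dgl.ops.e_dot_v(g, key, query).squeeze(-1) / head_dim ** 0.5
# normalize over the incoming edges of each destination node
weights = edge_softmax(g, logits)
# weighted sum of edge values into destination nodes
out = dgl.ops.copy_e_sum(g, weights.unsqueeze(-1) * value)
print(out.shape)  # torch.Size([3, 2, 4])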
Example #2
    @staticmethod
    def from_features(feats: Dict[str, Tensor]):
        """ Infer the Fiber structure from a feature dict """
        structure = {}
        for k, v in feats.items():
            degree = int(k)
            assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
            assert v.shape[-1] == degree_to_dim(degree)
            structure[degree] = v.shape[-2]
        return Fiber(structure)
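
A hypothetical usage sketch, assuming the Fiber class this method belongs to and degree_to_dim(d) = 2d + 1 (consistent with the '(N, C, 2D+1)' assertion above):

import torch

feats = {
    '0': torch.randn(10, 32, 1),  # 10 nodes, 32 scalar channels (2*0 + 1 = 1)
    '1': torch.randn(10, 16, 3),  # 16 vector channels (2*1 + 1 = 3)
}
fiber = Fiber.from_features(feats)  # inferred structure: {0: 32, 1: 16}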
Example #3
    @property
    def num_features(self):
        """ Size of the resulting tensor if all features were concatenated together """
        return sum(t.channels * degree_to_dim(t.degree)
                   for t in self.structure)
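
Continuing the illustration above: with degree_to_dim(d) = 2d + 1, a fiber with structure {0: 32, 1: 16} gives num_features = 32 * 1 + 16 * 3 = 80, the width of the tensor obtained by flattening each degree's features and concatenating them.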
Example #4
    def __init__(self,
                 fiber_in: Fiber,
                 fiber_out: Fiber,
                 fiber_edge: Fiber,
                 pool: bool = True,
                 use_layer_norm: bool = False,
                 self_interaction: bool = False,
                 max_degree: int = 4,
                 fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
                 allow_fused_output: bool = False):
        """
        :param fiber_in:           Fiber describing the input features
        :param fiber_out:          Fiber describing the output features
        :param fiber_edge:         Fiber describing the edge features (node distances excluded)
        :param pool:               If True, compute final node features by averaging incoming edge features
        :param use_layer_norm:     Apply layer normalization between MLP layers
        :param self_interaction:   Apply self-interaction of nodes
        :param max_degree:         Maximum degree used in the bases computation
        :param fuse_level:         Maximum fuse level to use in TFN convolutions
        :param allow_fused_output: Allow the module to output a fused representation of features
        """
        super().__init__()
        self.pool = pool
        self.fiber_in = fiber_in
        self.fiber_out = fiber_out
        self.self_interaction = self_interaction
        self.max_degree = max_degree
        self.allow_fused_output = allow_fused_output

        # channels_in: account for the concatenation of edge features
        channels_in_set = set([
            f.channels + fiber_edge[f.degree] * (f.degree > 0)
            for f in self.fiber_in
        ])
        channels_out_set = set([f.channels for f in self.fiber_out])
        unique_channels_in = (len(channels_in_set) == 1)
        unique_channels_out = (len(channels_out_set) == 1)
        degrees_up_to_max = list(range(max_degree + 1))
        common_args = dict(edge_dim=fiber_edge[0] + 1,
                           use_layer_norm=use_layer_norm)

        if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
                unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
                unique_channels_out and fiber_out.degrees == degrees_up_to_max:
            # Single fused convolution
            self.used_fuse_level = ConvSE3FuseLevel.FULL

            sum_freq = sum([
                degree_to_dim(min(d_in, d_out)) for d_in, d_out in product(
                    degrees_up_to_max, degrees_up_to_max)
            ])

            self.conv = VersatileConvSE3(sum_freq,
                                         list(channels_in_set)[0],
                                         list(channels_out_set)[0],
                                         fuse_level=self.used_fuse_level,
                                         **common_args)

        elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
                unique_channels_in and fiber_in.degrees == degrees_up_to_max:
            # Convolutions fused per output degree
            self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
            self.conv_out = nn.ModuleDict()
            for d_out, c_out in fiber_out:
                sum_freq = sum(
                    [degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
                self.conv_out[str(d_out)] = VersatileConvSE3(
                    sum_freq,
                    list(channels_in_set)[0],
                    c_out,
                    fuse_level=self.used_fuse_level,
                    **common_args)

        elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
                unique_channels_out and fiber_out.degrees == degrees_up_to_max:
            # Convolutions fused per input degree
            self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
            self.conv_in = nn.ModuleDict()
            for d_in, c_in in fiber_in:
                sum_freq = sum(
                    [degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
                channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0)
                self.conv_in[str(d_in)] = VersatileConvSE3(
                    sum_freq,
                    channels_in_new,
                    list(channels_out_set)[0],
                    fuse_level=self.used_fuse_level,
                    **common_args)
        else:
            # Use pairwise TFN convolutions
            self.used_fuse_level = ConvSE3FuseLevel.NONE
            self.conv = nn.ModuleDict()
            for (degree_in, channels_in), (degree_out, channels_out) in (
                    self.fiber_in * self.fiber_out):
                dict_key = f'{degree_in},{degree_out}'
                channels_in_new = channels_in + fiber_edge[degree_in] * (
                    degree_in > 0)
                sum_freq = degree_to_dim(min(degree_in, degree_out))
                self.conv[dict_key] = VersatileConvSE3(
                    sum_freq,
                    channels_in_new,
                    channels_out,
                    fuse_level=self.used_fuse_level,
                    **common_args)

        if self_interaction:
            self.to_kernel_self = nn.ParameterDict()
            for degree_out, channels_out in fiber_out:
                if fiber_in[degree_out]:
                    self.to_kernel_self[str(degree_out)] = nn.Parameter(
                        torch.randn(channels_out, fiber_in[degree_out]) /
                        np.sqrt(fiber_in[degree_out]))
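
A hypothetical instantiation, assuming this constructor belongs to a ConvSE3 class (consistent with ConvSE3FuseLevel and VersatileConvSE3 above) and that indexing a Fiber with a missing degree yields 0 channels:

fiber_in = Fiber({0: 32, 1: 32})   # degrees 0..1, one channel count everywhere
fiber_out = Fiber({0: 64, 1: 64})
fiber_edge = Fiber({0: 0})         # no scalar edge features besides the distance
conv = ConvSE3(fiber_in, fiber_out, fiber_edge, max_degree=1)
# Both fibers cover degrees 0..max_degree with uniform channels, so the first
# branch is taken: conv.used_fuse_level == ConvSE3FuseLevel.FULL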
Example #5
def get_spherical_harmonics(relative_pos: Tensor,
                            max_degree: int) -> List[Tensor]:
    all_degrees = list(range(2 * max_degree + 1))
    with nvtx_range('spherical harmonics'):
        sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
        return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
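
Note that the harmonics run up to degree 2 * max_degree: the equivariant basis couples an input degree d_in with an output degree d_out through harmonics of every order up to d_in + d_out. An illustrative call, assuming the snippet's dependencies (e3nn's o3, the nvtx_range helper, degree_to_dim(d) = 2d + 1) are available:

import torch

rel_pos = torch.randn(8, 3)  # 8 edges, 3-D relative positions
sh = get_spherical_harmonics(rel_pos, max_degree=2)
print([t.shape for t in sh])
# [torch.Size([8, 1]), torch.Size([8, 3]), torch.Size([8, 5]),
#  torch.Size([8, 7]), torch.Size([8, 9])]  -> one tensor per degree 0..4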
Example #6
def update_basis_with_fused(basis: Dict[str, Tensor], max_degree: int,
                            use_pad_trick: bool,
                            fully_fused: bool) -> Dict[str, Tensor]:
    """ Update the basis dict with partially and optionally fully fused bases """
    num_edges = basis['0,0'].shape[0]
    device = basis['0,0'].device
    dtype = basis['0,0'].dtype
    sum_dim = sum([degree_to_dim(d) for d in range(max_degree + 1)])

    # Fused per output degree
    for d_out in range(max_degree + 1):
        sum_freq = sum(
            [degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)])
        basis_fused = torch.zeros(num_edges,
                                  sum_dim,
                                  sum_freq,
                                  degree_to_dim(d_out) + int(use_pad_trick),
                                  device=device,
                                  dtype=dtype)
        acc_d, acc_f = 0, 0
        for d_in in range(max_degree + 1):
            basis_fused[:, acc_d:acc_d + degree_to_dim(d_in),
                        acc_f:acc_f + degree_to_dim(min(d_out, d_in)),
                        :degree_to_dim(d_out)] = \
                basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]

            acc_d += degree_to_dim(d_in)
            acc_f += degree_to_dim(min(d_out, d_in))

        basis[f'out{d_out}_fused'] = basis_fused

    # Fused per input degree
    for d_in in range(max_degree + 1):
        sum_freq = sum(
            [degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)])
        basis_fused = torch.zeros(num_edges,
                                  degree_to_dim(d_in),
                                  sum_freq,
                                  sum_dim,
                                  device=device,
                                  dtype=dtype)
        acc_d, acc_f = 0, 0
        for d_out in range(max_degree + 1):
            basis_fused[:, :, acc_f:acc_f + degree_to_dim(min(d_out, d_in)),
                        acc_d:acc_d + degree_to_dim(d_out)] = \
                basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]

            acc_d += degree_to_dim(d_out)
            acc_f += degree_to_dim(min(d_out, d_in))

        basis[f'in{d_in}_fused'] = basis_fused

    if fully_fused:
        # Fully fused
        # Double sum this way because of JIT script
        sum_freq = sum([
            sum([
                degree_to_dim(min(d_in, d_out))
                for d_in in range(max_degree + 1)
            ]) for d_out in range(max_degree + 1)
        ])
        basis_fused = torch.zeros(num_edges,
                                  sum_dim,
                                  sum_freq,
                                  sum_dim,
                                  device=device,
                                  dtype=dtype)

        acc_d, acc_f = 0, 0
        for d_out in range(max_degree + 1):
            b = basis[f'out{d_out}_fused']
            basis_fused[:, :, acc_f:acc_f + b.shape[2],
                        acc_d:acc_d + degree_to_dim(d_out)] = \
                b[:, :, :, :degree_to_dim(d_out)]
            acc_f += b.shape[2]
            acc_d += degree_to_dim(d_out)

        basis['fully_fused'] = basis_fused

    del basis['0,0']  # the basis for l = k = 0 is filled with a constant
    return basis
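
A worked shape example (illustrative, assuming degree_to_dim(d) = 2d + 1): with max_degree = 1 and use_pad_trick = False, sum_dim = 1 + 3 = 4 and the update adds the following entries:

# 'out1_fused': bases '0,1' and '1,1' stacked along the input-degree axis
#   sum_freq = dim(min(0, 1)) + dim(min(1, 1)) = 1 + 3 = 4
#   shape    = (num_edges, 4, 4, 3)
# 'in1_fused': bases '1,0' and '1,1' stacked along the output-degree axis
#   sum_freq = dim(min(1, 0)) + dim(min(1, 1)) = 1 + 3 = 4
#   shape    = (num_edges, 3, 4, 4)
# 'fully_fused': sum_freq = 1 + 1 + 1 + 3 = 6, shape = (num_edges, 4, 6, 4)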