예제 #1
0
    def forward(self, inp, edge_info, rel_dist=None, basis=None):
        """Apply the pairwise kernels to neighbor features, one output degree at a time.

        Args:
            inp: mapping from degree string (``str(degree_in)``) to a feature
                tensor; neighbors are gathered from it along dim 1.
            edge_info: 3-tuple ``(neighbor_indices, neighbor_masks, edges)``.
                ``edges`` may be ``None``; otherwise it is concatenated to the
                distances on the last axis.
            rel_dist: neighbor distances, rearranged here from ``b m n`` to
                ``b m n 1``.
            basis: forwarded unchanged to every kernel function.

        Returns:
            Dict keyed by ``str(degree_out)``. When ``self.self_interaction``
            is set, the result of ``self.self_interact_sum`` is returned
            instead.
        """
        neighbor_indices, neighbor_masks, edges = edge_info
        rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')

        # the kernel input is the same for every (degree_in, degree_out) pair
        if exists(edges):
            edge_features = torch.cat((rel_dist, edges), dim=-1)
        else:
            edge_features = rel_dist

        # one kernel per (input degree, output degree) pair
        kernels = {}
        for (d_in, _), (d_out, _) in (self.fiber_in * self.fiber_out):
            pair_key = f'({d_in},{d_out})'
            kernels[pair_key] = self.kernel_unary[pair_key](edge_features,
                                                            basis=basis)

        outputs = {}
        for degree_out in self.fiber_out.degrees:
            accumulated = 0

            for degree_in, m_in in self.fiber_in:
                x = batched_index_select(inp[str(degree_in)],
                                         neighbor_indices,
                                         dim=1)
                x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)

                pair_kernel = kernels[f'({degree_in},{degree_out})']
                accumulated = accumulated + einsum(
                    '... o i, ... i c -> ... o c', pair_kernel, x)

            if self.pool:
                # average over the neighbor axis, honoring the mask if given
                if exists(neighbor_masks):
                    accumulated = masked_mean(accumulated,
                                              neighbor_masks,
                                              dim=2)
                else:
                    accumulated = accumulated.mean(dim=2)
                leading_shape = x.shape[:2]
            else:
                leading_shape = x.shape[:3]

            outputs[str(degree_out)] = accumulated.view(
                *leading_shape, -1, to_order(degree_out))

        if self.self_interaction:
            outputs = self.self_interact_sum(outputs, self.self_interact(inp))

        return outputs
예제 #2
0
    def forward(self,
                feats,
                coors,
                mask=None,
                adj_mat=None,
                edges=None,
                return_type=None,
                return_pooled=False):
        """Run the network on node features and coordinates.

        Args:
            feats: a tensor (wrapped as the degree ``'0'`` entry with a
                trailing singleton axis) or a dict keyed by degree string.
                Passed through ``self.token_emb`` first when that is set.
            coors: node coordinates, consumed as ``(b, n, 3)``.
            mask: optional ``(b, n)`` node mask.
            adj_mat: optional adjacency matrix, ``(n, n)`` or ``(b, n, n)``.
                Required when ``self.attend_sparse_neighbors`` is set.
            edges: optional edge input, embedded with ``self.edge_emb``
                (presumably ``(b, n, n)`` token ids -- confirm with caller).
            return_type: degree whose tensor is returned; forced to ``0``
                when ``self.output_degrees == 1``. ``None`` returns the dict.
            return_pooled: if true, average every output over the node axis
                (masked by ``mask`` when given).

        Returns:
            ``x[str(return_type)]`` when ``return_type`` is set, otherwise the
            dict of per-degree outputs.

        Raises:
            AssertionError: on inconsistent arguments (see the asserts below).
        """
        # keep the per-node mask for pooling; `mask` becomes pairwise below
        _mask = mask

        if self.output_degrees == 1:
            return_type = 0

        if exists(self.token_emb):
            feats = self.token_emb(feats)

        assert not (
            self.attend_sparse_neighbors and not exists(adj_mat)
        ), 'adjacency matrix (adjacency_mat) or edges (edges) must be passed in'
        assert not (
            exists(edges) and not exists(self.edge_emb)
        ), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

        if torch.is_tensor(feats):
            feats = {'0': feats[..., None]}

        b, n, d, *_, device = *feats['0'].shape, feats['0'].device

        assert d == self.dim, f'feature dimension {d} must be equal to dimension given at init {self.dim}'
        assert set(map(int, feats.keys())) == set(
            range(self.input_degrees
                  )), f'input must have {self.input_degrees} degree'

        num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius

        assert self.attend_sparse_neighbors or neighbors > 0, 'you must either attend to sparsely bonded neighbors, or set number of locally attended neighbors to be greater than 0'

        # se3 transformer by default cannot have a node attend to itself

        exclude_self_mask = rearrange(
            ~torch.eye(n, dtype=torch.bool, device=device), 'i j -> () i j')

        def remove_self(t):
            # drop the diagonal of a (b, n, n) tensor -> (b, n, n - 1)
            return t.masked_select(exclude_self_mask).reshape(b, n, n - 1)

        # create N-degrees adjacent matrix from 1st degree connections

        if exists(self.num_adj_degrees):
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b=b)

            adj_indices = adj_mat.clone().long()

            for ind in range(self.num_adj_degrees - 1):
                degree = ind + 2

                next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
                # pairs that changed between the previous and the new matrix
                next_degree_mask = (next_degree_adj_mat.float() -
                                    adj_mat.float()).bool()
                adj_indices.masked_fill_(next_degree_mask, degree)
                adj_mat = next_degree_adj_mat.clone()

            # fix: strip the diagonal so the adjacency embedding lines up with
            # every other pairwise tensor, which is (b, n, n - 1)
            adj_indices = remove_self(adj_indices)

        # calculate sparsely connected neighbors

        sparse_neighbor_mask = None
        num_sparse_neighbors = 0

        if self.attend_sparse_neighbors:
            assert exists(
                adj_mat
            ), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            # fix: test the rank, not the length of the first axis
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat, 'i j -> b i j', b=b)

            adj_mat = remove_self(adj_mat)

            adj_mat_values = adj_mat.float()
            adj_mat_max_neighbors = adj_mat_values.sum(dim=-1).max().item()

            if max_sparse_neighbors < adj_mat_max_neighbors:
                # small noise breaks ties when topk must drop bonded neighbors
                noise = torch.empty_like(adj_mat_values).uniform_(-0.01, 0.01)
                adj_mat_values += noise

            num_sparse_neighbors = int(
                min(max_sparse_neighbors, adj_mat_max_neighbors))
            values, indices = adj_mat_values.topk(num_sparse_neighbors, dim=-1)
            sparse_neighbor_mask = torch.zeros_like(adj_mat_values).scatter_(
                -1, indices, values)
            sparse_neighbor_mask = sparse_neighbor_mask > 0.5

        # exclude edge of token to itself

        # fix: entry [b, i, j] must hold the column (neighbor) index j; the
        # previous 'i -> b i j' pattern filled row i with i, so every gathered
        # "neighbor" was the node itself
        indices = repeat(torch.arange(n, device=device),
                         'j -> b i j',
                         b=b,
                         i=n)
        rel_pos = rearrange(coors, 'b n d -> b n () d') - rearrange(
            coors, 'b n d -> b () n d')

        indices = remove_self(indices)
        rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(
            b, n, n - 1, 3)

        if exists(mask):
            # pairwise mask: both endpoints must be valid
            mask = rearrange(mask, 'b i -> b i ()') * rearrange(
                mask, 'b j -> b () j')
            mask = remove_self(mask)

        if exists(edges):
            edges = self.edge_emb(edges)
            edges = edges.masked_select(exclude_self_mask[..., None]).reshape(
                b, n, n - 1, -1)

        if exists(self.adj_emb):
            adj_emb = self.adj_emb(adj_indices)
            edges = torch.cat(
                (edges, adj_emb), dim=-1) if exists(edges) else adj_emb

        rel_dist = rel_pos.norm(dim=-1)

        # use sparse neighbor mask to assign priority of bonded

        # fix: work on a copy so the true distances in `rel_dist` are not
        # overwritten by the in-place fill below
        modified_rel_dist = rel_dist.clone()
        if exists(sparse_neighbor_mask):
            modified_rel_dist.masked_fill_(sparse_neighbor_mask, 0.)

        # if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix

        if neighbors == 0:
            valid_radius = 0

        # get neighbors and neighbor mask, excluding self

        neighbors = int(min(neighbors, n - 1))
        total_neighbors = int(neighbors + num_sparse_neighbors)
        assert total_neighbors > 0, 'you must be fetching at least 1 neighbor'

        total_neighbors = int(
            min(total_neighbors, n - 1)
        )  # make sure total neighbors does not exceed the length of the sequence itself

        dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors,
                                                              dim=-1,
                                                              largest=False)
        neighbor_mask = dist_values <= valid_radius

        neighbor_rel_dist = batched_index_select(rel_dist,
                                                 nearest_indices,
                                                 dim=2)
        neighbor_rel_pos = batched_index_select(rel_pos,
                                                nearest_indices,
                                                dim=2)
        neighbor_indices = batched_index_select(indices,
                                                nearest_indices,
                                                dim=2)

        if exists(mask):
            neighbor_mask = neighbor_mask & batched_index_select(
                mask, nearest_indices, dim=2)

        if exists(edges):
            edges = batched_index_select(edges, nearest_indices, dim=2)

        # calculate basis

        basis = get_basis(neighbor_rel_pos,
                          num_degrees - 1,
                          differentiable=self.differentiable_coors)

        # main logic

        edge_info = (neighbor_indices, neighbor_mask, edges)
        x = feats

        # project in

        x = self.conv_in(x, edge_info, rel_dist=neighbor_rel_dist, basis=basis)

        # transformer layers

        x = self.net(x,
                     edge_info=edge_info,
                     rel_dist=neighbor_rel_dist,
                     basis=basis)

        # project out

        x = self.conv_out(x,
                          edge_info,
                          rel_dist=neighbor_rel_dist,
                          basis=basis)

        # norm

        x = self.norm(x)

        # reduce dim if specified

        if exists(self.linear_out):
            x = self.linear_out(x)
            x = map_values(lambda t: t.squeeze(dim=2), x)

        if return_pooled:
            mask_fn = (lambda t: masked_mean(t, _mask, dim=1)
                       ) if exists(_mask) else (lambda t: t.mean(dim=1))
            x = map_values(mask_fn, x)

        if '0' in x:
            x['0'] = x['0'].squeeze(dim=-1)

        if exists(return_type):
            return x[str(return_type)]

        return x
    def forward(
        self,
        feats,
        coors,
        mask = None,
        adj_mat = None,
        edges = None,
        return_type = None,
        return_pooled = False,
        neighbor_mask = None,
        global_feats = None
    ):
        """Run the network on node features and coordinates.

        Args:
            feats: a tensor (wrapped as the degree ``'0'`` entry with a
                trailing singleton axis) or a dict keyed by degree string.
                Passed through ``self.token_emb`` first when that is set.
            coors: node coordinates, consumed as ``(b, n, 3)``.
            mask: optional ``(b, n)`` node mask.
            adj_mat: optional adjacency matrix, ``(n, n)`` or ``(b, n, n)``.
                Required when ``self.attend_sparse_neighbors`` is set.
            edges: optional edge input; required when ``self.has_edges``.
                Embedded with ``self.edge_emb`` when that is set.
            return_type: degree whose tensor is returned; forced to ``0``
                when ``self.output_degrees == 1``. ``None`` returns the dict.
            return_pooled: if true, average every output over the node axis
                (masked by ``mask`` when given).
            neighbor_mask: optional boolean pairwise mask restricting which
                nodes may be picked as neighbors (presumably ``(b, n, n)``
                -- confirm with caller).
            global_feats: optional global features, a tensor or a dict keyed
                by degree string. Must be given exactly when
                ``self.accept_global_feats`` is set.

        Returns:
            ``x[str(return_type)]`` when ``return_type`` is set, otherwise the
            dict of per-degree outputs.

        Raises:
            AssertionError: on inconsistent arguments (see the asserts below).
        """
        assert not (self.accept_global_feats ^ exists(global_feats)), 'you cannot pass in global features unless you init the class correctly'

        # keep the per-node mask for pooling; `mask` becomes pairwise below
        _mask = mask

        if self.output_degrees == 1:
            return_type = 0

        if exists(self.token_emb):
            feats = self.token_emb(feats)

        assert not (self.attend_sparse_neighbors and not exists(adj_mat)), 'adjacency matrix (adjacency_mat) or edges (edges) must be passed in'
        assert not (self.has_edges and not exists(edges)), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

        if torch.is_tensor(feats):
            feats = {'0': feats[..., None]}

        if torch.is_tensor(global_feats):
            global_feats = {'0': global_feats[..., None]}

        b, n, d, *_, device = *feats['0'].shape, feats['0'].device

        assert d == self.dim_in[0], f'feature dimension {d} must be equal to dimension given at init {self.dim_in[0]}'
        assert set(map(int, feats.keys())) == set(range(self.input_degrees)), f'input must have {self.input_degrees} degree'

        num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius

        assert self.attend_sparse_neighbors or neighbors > 0, 'you must either attend to sparsely bonded neighbors, or set number of locally attended neighbors to be greater than 0'

        # se3 transformer by default cannot have a node attend to itself

        exclude_self_mask = rearrange(~torch.eye(n, dtype = torch.bool, device = device), 'i j -> () i j')
        # drop the diagonal of a (b, n, n) tensor -> (b, n, n - 1)
        remove_self = lambda t: t.masked_select(exclude_self_mask).reshape(b, n, n - 1)
        get_max_value = lambda t: torch.finfo(t.dtype).max

        # create N-degrees adjacent matrix from 1st degree connections

        if exists(self.num_adj_degrees):
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

            adj_indices = adj_mat.clone().long()

            for ind in range(self.num_adj_degrees - 1):
                degree = ind + 2

                next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
                # pairs that changed between the previous and the new matrix
                next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
                adj_indices.masked_fill_(next_degree_mask, degree)
                adj_mat = next_degree_adj_mat.clone()

            adj_indices = adj_indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)

        # calculate sparsely connected neighbors

        sparse_neighbor_mask = None
        num_sparse_neighbors = 0

        if self.attend_sparse_neighbors:
            assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            if exists(adj_mat):
                if len(adj_mat.shape) == 2:
                    adj_mat = repeat(adj_mat, 'i j -> b i j', b = b)

            adj_mat = remove_self(adj_mat)

            adj_mat_values = adj_mat.float()
            adj_mat_max_neighbors = adj_mat_values.sum(dim = -1).max().item()

            if max_sparse_neighbors < adj_mat_max_neighbors:
                # small noise breaks ties when topk must drop bonded neighbors
                noise = torch.empty_like(adj_mat_values).uniform_(-0.01, 0.01)
                adj_mat_values += noise

            num_sparse_neighbors = int(min(max_sparse_neighbors, adj_mat_max_neighbors))
            values, indices = adj_mat_values.topk(num_sparse_neighbors, dim = -1)
            sparse_neighbor_mask = torch.zeros_like(adj_mat_values).scatter_(-1, indices, values)
            sparse_neighbor_mask = sparse_neighbor_mask > 0.5

        # exclude edge of token to itself

        # fix: entry [b, i, j] must hold the column (neighbor) index j; the
        # previous 'i -> b i j' pattern filled row i with i, so every gathered
        # "neighbor" was the node itself
        indices = repeat(torch.arange(n, device = device), 'j -> b i j', b = b, i = n)
        rel_pos  = rearrange(coors, 'b n d -> b n () d') - rearrange(coors, 'b n d -> b () n d')

        indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
        rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, 3)

        if exists(mask):
            # pairwise mask: both endpoints must be valid
            mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
            mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)

        if exists(edges):
            if exists(self.edge_emb):
                edges = self.edge_emb(edges)

            edges = edges.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, -1)

        if exists(self.adj_emb):
            adj_emb = self.adj_emb(adj_indices)
            edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb

        rel_dist = rel_pos.norm(dim = -1)

        # rel_dist gets modified using adjacency or neighbor mask

        # work on a copy: the true distances are still needed after ranking
        modified_rel_dist = rel_dist.clone()
        max_value = get_max_value(modified_rel_dist) # for masking out nodes from being considered as neighbors

        # neighbors

        if exists(neighbor_mask):
            neighbor_mask = remove_self(neighbor_mask)

            max_neighbors = neighbor_mask.sum(dim = -1).max().item()
            if max_neighbors > neighbors:
                print(f'neighbor_mask shows maximum number of neighbors as {max_neighbors} but specified number of neighbors is {neighbors}')

            modified_rel_dist.masked_fill_(~neighbor_mask, max_value)

        # use sparse neighbor mask to assign priority of bonded

        if exists(sparse_neighbor_mask):
            modified_rel_dist.masked_fill_(sparse_neighbor_mask, 0.)

        # mask out future nodes to high distance if causal turned on

        if self.causal:
            # columns are self-excluded, so column j >= i is a later node
            causal_mask = torch.ones(n, n - 1, device = device).triu().bool()
            modified_rel_dist.masked_fill_(causal_mask[None, ...], max_value)

        # if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix

        if neighbors == 0:
            valid_radius = 0

        # get neighbors and neighbor mask, excluding self

        neighbors = int(min(neighbors, n - 1))
        total_neighbors = int(neighbors + num_sparse_neighbors)
        assert total_neighbors > 0, 'you must be fetching at least 1 neighbor'

        total_neighbors = int(min(total_neighbors, n - 1)) # make sure total neighbors does not exceed the length of the sequence itself

        dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors, dim = -1, largest = False)
        # from here on `neighbor_mask` is the mask over the selected neighbors
        neighbor_mask = dist_values <= valid_radius

        neighbor_rel_dist = batched_index_select(rel_dist, nearest_indices, dim = 2)
        neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim = 2)
        neighbor_indices = batched_index_select(indices, nearest_indices, dim = 2)

        if exists(mask):
            neighbor_mask = neighbor_mask & batched_index_select(mask, nearest_indices, dim = 2)

        if exists(edges):
            edges = batched_index_select(edges, nearest_indices, dim = 2)

        # calculate rotary pos emb

        rotary_pos_emb = None
        rotary_query_pos_emb = None
        rotary_key_pos_emb = None

        if self.rotary_position:
            seq = torch.arange(n, device = device)
            seq_pos_emb = self.rotary_pos_emb(seq)
            # keys are laid out as [self, neighbors...]
            self_indices = torch.arange(neighbor_indices.shape[1], device = device)
            self_indices = repeat(self_indices, 'i -> b i ()', b = b)
            neighbor_indices_with_self = torch.cat((self_indices, neighbor_indices), dim = 2)
            pos_emb = batched_index_select(seq_pos_emb, neighbor_indices_with_self, dim = 0)

            rotary_key_pos_emb = pos_emb
            rotary_query_pos_emb = repeat(seq_pos_emb, 'n d -> b n d', b = b)

        if self.rotary_rel_dist:
            # prepend a zero distance for the self slot, then scale by 1e2
            neighbor_rel_dist_with_self = F.pad(neighbor_rel_dist, (1, 0), value = 0) * 1e2
            rel_dist_pos_emb = self.rotary_pos_emb(neighbor_rel_dist_with_self)
            rotary_key_pos_emb = safe_cat(rotary_key_pos_emb, rel_dist_pos_emb, dim = -1)

            query_dist = torch.zeros(n, device = device)
            query_pos_emb = self.rotary_pos_emb(query_dist)
            query_pos_emb = repeat(query_pos_emb, 'n d -> b n d', b = b)

            rotary_query_pos_emb = safe_cat(rotary_query_pos_emb, query_pos_emb, dim = -1)

        if exists(rotary_query_pos_emb) and exists(rotary_key_pos_emb):
            rotary_pos_emb = (rotary_query_pos_emb, rotary_key_pos_emb)

        # calculate basis

        basis = get_basis(neighbor_rel_pos, num_degrees - 1, differentiable = self.differentiable_coors)

        # main logic

        edge_info = (neighbor_indices, neighbor_mask, edges)
        x = feats

        # project in

        x = self.conv_in(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)

        # preconvolution layers

        for conv, nonlin in self.convs:
            x = nonlin(x)
            x = conv(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)

        # transformer layers

        x = self.net(x, edge_info = edge_info, rel_dist = neighbor_rel_dist, basis = basis, global_feats = global_feats, pos_emb = rotary_pos_emb)

        # project out

        x = self.conv_out(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)

        # norm

        x = self.norm(x)

        # reduce dim if specified

        if exists(self.linear_out):
            x = self.linear_out(x)
            x = map_values(lambda t: t.squeeze(dim = 2), x)

        if return_pooled:
            mask_fn = (lambda t: masked_mean(t, _mask, dim = 1)) if exists(_mask) else (lambda t: t.mean(dim = 1))
            x = map_values(mask_fn, x)

        if '0' in x:
            x['0'] = x['0'].squeeze(dim = -1)

        if exists(return_type):
            return x[str(return_type)]

        return x
    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None):
        """Attend from every node to its neighbors, separately per degree.

        Args:
            features: dict keyed by degree string; its keys drive the
                per-degree loop.
            edge_info: 3-tuple ``(neighbor_indices, neighbor_mask, edges)``,
                also forwarded whole to the value/key projections.
            rel_dist: forwarded to the value/key projections.
            basis: forwarded to the value/key projections.
            global_feats: optional input to ``self.to_global_k`` /
                ``self.to_global_v``; the result is prepended to the keys and
                values of degree ``'0'`` only.
            pos_emb: optional ``(query_pos_emb, key_pos_emb)`` pair, applied
                to degree ``'0'`` only.

        Returns:
            ``self.to_out`` applied to the dict of per-degree attention
            outputs.
        """
        h, attend_self = self.heads, self.attend_self
        neighbor_indices, neighbor_mask, _ = edge_info

        if exists(neighbor_mask):
            # add a broadcast axis for the heads
            neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')

        queries = self.to_q(features)
        values  = self.to_v(features, edge_info, rel_dist, basis)

        if self.linear_proj_keys:
            keys = self.to_k(features)
            keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys)
        elif not exists(self.to_k):
            # no key projection configured: keys are the values
            keys = values
        else:
            keys = self.to_k(features, edge_info, rel_dist, basis)

        if attend_self:
            self_keys, self_values = self.to_self_k(features), self.to_self_v(features)

        if exists(global_feats):
            global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats)

        outputs = {}
        for degree in features.keys():
            q, k, v = map(lambda t: t[degree], (queries, keys, values))

            q = rearrange(q, 'b i (h d) m -> b h i d m', h = h)

            # each optional block below prepends entries along the key axis

            if attend_self:
                self_k, self_v = map(lambda t: t[degree], (self_keys, self_values))
                self_k, self_v = map(lambda t: rearrange(t, 'b n d m -> b n () d m'), (self_k, self_v))
                k = torch.cat((self_k, k), dim = 2)
                v = torch.cat((self_v, v), dim = 2)

            if exists(pos_emb) and degree == '0':
                query_pos_emb, key_pos_emb = pos_emb
                query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
                key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b i j d ()')
                q = apply_rotary_pos_emb(q, query_pos_emb)
                k = apply_rotary_pos_emb(k, key_pos_emb)
                v = apply_rotary_pos_emb(v, key_pos_emb)

            if self.use_null_kv:
                null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values))
                null_k, null_v = map(lambda t: repeat(t, 'd m -> b i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v))
                k = torch.cat((null_k, k), dim = 2)
                v = torch.cat((null_v, v), dim = 2)

            if exists(global_feats) and degree == '0':
                global_k, global_v = map(lambda t: t[degree], (global_keys, global_values))
                global_k, global_v = map(lambda t: repeat(t, 'b j d m -> b i j d m', i = k.shape[1]), (global_k, global_v))
                k = torch.cat((global_k, k), dim = 2)
                v = torch.cat((global_v, v), dim = 2)

            sim = einsum('b h i d m, b i j d m -> b h i j', q, k) * self.scale

            if exists(neighbor_mask):
                # the prepended keys are always attendable, so left-pad the
                # neighbor mask with True for them
                num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
                mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
                # fix: take the fill value from the logits' own dtype; the
                # default-dtype maximum does not fit in half precision
                sim.masked_fill_(~mask, -torch.finfo(sim.dtype).max)

            attn = sim.softmax(dim = -1)
            out = einsum('b h i j, b i j d m -> b h i d m', attn, v)
            outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')

        return self.to_out(outputs)
    def forward(
        self,
        inp,
        edge_info,
        rel_dist = None,
        basis = None
    ):
        """Apply the pairwise kernels to neighbor features, in chunks along dim 1.

        Args:
            inp: mapping from degree string (``str(degree_in)``) to a feature
                tensor; neighbors are gathered from it along dim 1.
            edge_info: 3-tuple ``(neighbor_indices, neighbor_masks, edges)``.
                ``edges`` may be ``None``; otherwise it is concatenated to the
                distances on the last axis.
            rel_dist: neighbor distances, rearranged here from ``b m n`` to
                ``b m n 1`` and optionally passed through ``fourier_encode``.
            basis: dict of tensors; each is split into ``self.splits`` pieces
                along dim 1.

        Returns:
            Dict keyed by ``str(degree_out)``. When ``self.self_interaction``
            is set, the result of ``self.self_interact_sum`` is returned
            instead.
        """
        num_splits = self.splits
        neighbor_indices, neighbor_masks, edges = edge_info
        rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')

        if self.fourier_encode_dist:
            rel_dist = fourier_encode(rel_dist[..., None], num_encodings = self.num_fourier_features)

        # split basis: one {key: piece} dict per chunk

        basis_names = basis.keys()
        pieces_per_chunk = list(zip(*[fast_split(t, num_splits, dim = 1) for t in basis.values()]))
        basis_chunks = [dict(zip(basis_names, pieces)) for pieces in pieces_per_chunk]

        # the kernel input is the same for every (degree_in, degree_out) pair

        if exists(edges):
            edge_features = torch.cat((rel_dist, edges), dim = -1)
        else:
            edge_features = rel_dist

        # go through every permutation of input degree type to output degree type

        outputs = {}

        for degree_out in self.fiber_out.degrees:
            accumulated = 0

            for degree_in, m_in in self.fiber_in:
                x = batched_index_select(inp[str(degree_in)], neighbor_indices, dim = 1)
                x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)

                kernel_fn = self.kernel_unary[f'({degree_in},{degree_out})']

                # process input, edges, and basis in chunks along the sequence dimension

                chunked_inputs = zip(
                    fast_split(x, num_splits, dim = 1),
                    fast_split(edge_features, num_splits, dim = 1),
                    basis_chunks
                )

                merged = None
                for x_part, edge_part, basis_part in chunked_inputs:
                    kernel = kernel_fn(edge_part, basis = basis_part)
                    part_out = einsum('... o i, ... i c -> ... o c', kernel, x_part)
                    merged = safe_cat(merged, part_out, dim = 1)

                accumulated = accumulated + merged

            if self.pool:
                # average over the neighbor axis, honoring the mask if given
                if exists(neighbor_masks):
                    accumulated = masked_mean(accumulated, neighbor_masks, dim = 2)
                else:
                    accumulated = accumulated.mean(dim = 2)
                leading_shape = x.shape[:2]
            else:
                leading_shape = x.shape[:3]

            outputs[str(degree_out)] = accumulated.view(*leading_shape, -1, to_order(degree_out))

        if self.self_interaction:
            outputs = self.self_interact_sum(outputs, self.self_interact(inp))

        return outputs
예제 #6
0
    def forward(self, feats, coors, mask=None, edges=None, return_type=None):
        """Run the network on node features and coordinates.

        Args:
            feats: a tensor (wrapped as the degree ``'0'`` entry with a
                trailing singleton axis) or a dict keyed by degree string.
                Passed through ``self.token_emb`` first when that is set.
            coors: node coordinates, consumed as ``(b, n, 3)``.
            mask: optional ``(b, n)`` node mask.
            edges: optional edge input, embedded with ``self.edge_emb``
                (presumably ``(b, n, n)`` token ids -- confirm with caller).
            return_type: degree whose tensor is returned; ``None`` returns
                the whole dict.

        Returns:
            ``x[str(return_type)]`` when ``return_type`` is set, otherwise the
            dict of per-degree outputs.

        Raises:
            AssertionError: if ``edges`` is given without an edge embedding,
                or the feature dimension / degrees do not match the module.
        """
        if exists(self.token_emb):
            feats = self.token_emb(feats)

        assert not (
            exists(edges) and not exists(self.edge_emb)
        ), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

        if exists(edges):
            edges = self.edge_emb(edges)

        if torch.is_tensor(feats):
            feats = {'0': feats[..., None]}

        b, n, d, *_, device = *feats['0'].shape, feats['0'].device

        assert d == self.dim, f'feature dimension {d} must be equal to dimension given at init {self.dim}'
        assert set(map(int, feats.keys())) == set(
            range(self.input_degrees
                  )), f'input must have {self.input_degrees} degree'

        num_degrees, neighbors = self.num_degrees, self.num_neighbors
        # a node has at most n - 1 other nodes to pick from
        neighbors = min(neighbors, n - 1)

        # exclude edge of token to itself

        exclude_self_mask = rearrange(
            ~torch.eye(n, dtype=torch.bool, device=device), 'i j -> () i j')

        # fix: entry [b, i, j] must hold the column (neighbor) index j; the
        # previous 'i -> b i j' pattern filled row i with i, so every gathered
        # "neighbor" was the node itself
        indices = repeat(torch.arange(n, device=device),
                         'j -> b i j',
                         b=b,
                         i=n)
        rel_pos = rearrange(coors, 'b n d -> b n () d') - rearrange(
            coors, 'b n d -> b () n d')

        indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
        rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(
            b, n, n - 1, 3)

        if exists(mask):
            # pairwise mask: both endpoints must be valid
            mask = rearrange(mask, 'b i -> b i ()') * rearrange(
                mask, 'b j -> b () j')
            mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)

        if exists(edges):
            edges = edges.masked_select(exclude_self_mask[..., None]).reshape(
                b, n, n - 1, -1)

        rel_dist = rel_pos.norm(dim=-1)

        # get neighbors and neighbor mask, excluding self

        neighbor_rel_dist, nearest_indices = rel_dist.topk(neighbors,
                                                           dim=-1,
                                                           largest=False)
        neighbor_rel_pos = batched_index_select(rel_pos,
                                                nearest_indices,
                                                dim=2)
        neighbor_indices = batched_index_select(indices,
                                                nearest_indices,
                                                dim=2)

        basis = get_basis(neighbor_rel_pos,
                          num_degrees - 1,
                          differentiable=self.differentiable_coors)

        neighbor_mask = neighbor_rel_dist <= self.valid_radius

        if exists(mask):
            neighbor_mask = neighbor_mask & batched_index_select(
                mask, nearest_indices, dim=2)

        if exists(edges):
            edges = batched_index_select(edges, nearest_indices, dim=2)

        # main logic

        edge_info = (neighbor_indices, neighbor_mask, edges)
        x = feats

        # project in

        x = self.conv_in(x, edge_info, rel_dist=neighbor_rel_dist, basis=basis)

        # transformer layers

        x = self.net(x,
                     edge_info=edge_info,
                     rel_dist=neighbor_rel_dist,
                     basis=basis)

        # project out

        x = self.conv_out(x,
                          edge_info,
                          rel_dist=neighbor_rel_dist,
                          basis=basis)

        # norm

        x = self.norm(x)

        # reduce dim if specified

        if exists(self.linear_out):
            x = self.linear_out(x)
            x = {k: v.squeeze(dim=2) for k, v in x.items()}

        if exists(return_type):
            return x[str(return_type)]

        return x