Esempio n. 1
0
    def forward(self,
                feats,
                coors,
                mask=None,
                adj_mat=None,
                edges=None,
                return_type=None,
                return_pooled=False):
        """Select neighbors for every node and run the equivariant network.

        Args:
            feats: either a tensor (passed through ``self.token_emb`` when that
                exists, then wrapped as ``{'0': feats[..., None]}``) or a dict
                keyed by degree strings ('0', '1', ...).
            coors: node coordinates, rearranged here as ``'b n d'``; relative
                positions are reshaped with a trailing size of 3.
            mask: optional ``'b n'`` node mask. A pair is kept only when both
                of its nodes are unmasked. The original mask is also used for
                pooling when ``return_pooled`` is set.
            adj_mat: optional adjacency matrix of shape ``(n, n)`` or
                ``(b, n, n)``. Required when ``self.attend_sparse_neighbors``
                is set or when ``self.num_adj_degrees`` is configured.
            edges: optional edge input, passed through ``self.edge_emb``
                (assumed to yield ``(b, n, n, dim)`` - TODO confirm).
            return_type: optional degree whose features are returned; forced
                to 0 when ``self.output_degrees == 1``.
            return_pooled: when true, every degree is reduced over dim 1
                (masked mean if ``mask`` was given, plain mean otherwise).

        Returns:
            The dict of features keyed by degree string, or the single entry
            ``x[str(return_type)]`` when ``return_type`` is set.

        Raises:
            AssertionError: on missing adjacency matrix / edge embedding,
                feature dimension or degree mismatch, or when no neighbor
                would be fetched.
        """
        _mask = mask

        if self.output_degrees == 1:
            return_type = 0

        if exists(self.token_emb):
            feats = self.token_emb(feats)

        assert not (
            self.attend_sparse_neighbors and not exists(adj_mat)
        ), 'adjacency matrix (adjacency_mat) or edges (edges) must be passed in'
        assert not (
            exists(edges) and not exists(self.edge_emb)
        ), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

        if torch.is_tensor(feats):
            feats = {'0': feats[..., None]}

        b, n, d, *_, device = *feats['0'].shape, feats['0'].device

        assert d == self.dim, f'feature dimension {d} must be equal to dimension given at init {self.dim}'
        assert set(map(int, feats.keys())) == set(
            range(self.input_degrees
                  )), f'input must have {self.input_degrees} degree'

        num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius

        assert self.attend_sparse_neighbors or neighbors > 0, 'you must either attend to sparsely bonded neighbors, or set number of locally attended neighbors to be greater than 0'

        # se3 transformer by default cannot have a node attend to itself
        # (built before the adjacency handling because both branches need it)

        exclude_self_mask = rearrange(
            ~torch.eye(n, dtype=torch.bool, device=device), 'i j -> () i j')

        # create N-degrees adjacent matrix from 1st degree connections

        if exists(self.num_adj_degrees):
            assert exists(
                adj_mat
            ), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b=b)

            adj_indices = adj_mat.clone().long()

            for ind in range(self.num_adj_degrees - 1):
                degree = ind + 2

                # entries where the squared adjacency differs from the current
                # one are labelled with the current degree
                next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
                next_degree_mask = (next_degree_adj_mat.float() -
                                    adj_mat.float()).bool()
                adj_indices.masked_fill_(next_degree_mask, degree)
                adj_mat = next_degree_adj_mat.clone()

            # drop self-pairs so the adjacency embedding lines up with every
            # other (b, n, n - 1) edge tensor that is concatenated / gathered
            adj_indices = adj_indices.masked_select(exclude_self_mask).reshape(
                b, n, n - 1)

        # calculate sparsely connected neighbors

        sparse_neighbor_mask = None
        num_sparse_neighbors = 0

        if self.attend_sparse_neighbors:
            assert exists(
                adj_mat
            ), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            # check the rank, not the size of the first axis
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat, 'i j -> b i j', b=b)

            adj_mat = adj_mat.masked_select(exclude_self_mask).reshape(
                b, n, n - 1)

            adj_mat_values = adj_mat.float()
            adj_mat_max_neighbors = adj_mat_values.sum(dim=-1).max().item()

            # more bonded neighbors than allowed: jitter the values before
            # topk (presumably to pick among equally bonded neighbors at random)
            if max_sparse_neighbors < adj_mat_max_neighbors:
                noise = torch.empty_like(adj_mat_values).uniform_(-0.01, 0.01)
                adj_mat_values += noise

            num_sparse_neighbors = int(
                min(max_sparse_neighbors, adj_mat_max_neighbors))
            values, indices = adj_mat_values.topk(num_sparse_neighbors, dim=-1)
            sparse_neighbor_mask = torch.zeros_like(adj_mat_values).scatter_(
                -1, indices, values)
            # the jitter is at most 0.01, so 0.5 still separates set from unset
            sparse_neighbor_mask = sparse_neighbor_mask > 0.5

        # exclude edge of token to itself

        indices = repeat(torch.arange(n, device=device),
                         'i -> b i j',
                         b=b,
                         j=n)
        rel_pos = rearrange(coors, 'b n d -> b n () d') - rearrange(
            coors, 'b n d -> b () n d')

        indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
        rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(
            b, n, n - 1, 3)

        if exists(mask):
            mask = rearrange(mask, 'b i -> b i ()') * rearrange(
                mask, 'b j -> b () j')
            mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)

        if exists(edges):
            edges = self.edge_emb(edges)
            edges = edges.masked_select(exclude_self_mask[..., None]).reshape(
                b, n, n - 1, -1)

        if exists(self.adj_emb):
            adj_emb = self.adj_emb(adj_indices)
            edges = torch.cat(
                (edges, adj_emb), dim=-1) if exists(edges) else adj_emb

        rel_dist = rel_pos.norm(dim=-1)

        # use sparse neighbor mask to assign priority of bonded
        # work on a copy: rel_dist itself is gathered below as the true
        # neighbor distance and must not be overwritten in place

        modified_rel_dist = rel_dist.clone()
        if exists(sparse_neighbor_mask):
            modified_rel_dist.masked_fill_(sparse_neighbor_mask, 0.)

        # if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix

        if neighbors == 0:
            valid_radius = 0

        # get neighbors and neighbor mask, excluding self

        neighbors = int(min(neighbors, n - 1))
        total_neighbors = int(neighbors + num_sparse_neighbors)
        assert total_neighbors > 0, 'you must be fetching at least 1 neighbor'

        total_neighbors = int(
            min(total_neighbors, n - 1)
        )  # make sure total neighbors does not exceed the length of the sequence itself

        dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors,
                                                              dim=-1,
                                                              largest=False)
        neighbor_mask = dist_values <= valid_radius

        neighbor_rel_dist = batched_index_select(rel_dist,
                                                 nearest_indices,
                                                 dim=2)
        neighbor_rel_pos = batched_index_select(rel_pos,
                                                nearest_indices,
                                                dim=2)
        neighbor_indices = batched_index_select(indices,
                                                nearest_indices,
                                                dim=2)

        if exists(mask):
            neighbor_mask = neighbor_mask & batched_index_select(
                mask, nearest_indices, dim=2)

        if exists(edges):
            edges = batched_index_select(edges, nearest_indices, dim=2)

        # calculate basis

        basis = get_basis(neighbor_rel_pos,
                          num_degrees - 1,
                          differentiable=self.differentiable_coors)

        # main logic

        edge_info = (neighbor_indices, neighbor_mask, edges)
        x = feats

        # project in

        x = self.conv_in(x, edge_info, rel_dist=neighbor_rel_dist, basis=basis)

        # transformer layers

        x = self.net(x,
                     edge_info=edge_info,
                     rel_dist=neighbor_rel_dist,
                     basis=basis)

        # project out

        x = self.conv_out(x,
                          edge_info,
                          rel_dist=neighbor_rel_dist,
                          basis=basis)

        # norm

        x = self.norm(x)

        # reduce dim if specified

        if exists(self.linear_out):
            x = self.linear_out(x)
            x = map_values(lambda t: t.squeeze(dim=2), x)

        if return_pooled:
            mask_fn = (lambda t: masked_mean(t, _mask, dim=1)
                       ) if exists(_mask) else (lambda t: t.mean(dim=1))
            x = map_values(mask_fn, x)

        if '0' in x:
            x['0'] = x['0'].squeeze(dim=-1)

        if exists(return_type):
            return x[str(return_type)]

        return x
def test_basis():
    """Check that get_basis yields (max_degree + 1) ** 2 keyed entries."""
    max_degree = 3
    expected_kernels = (max_degree + 1) ** 2

    # random input of shape (2, 1024, 3)
    points = torch.randn(2, 1024, 3)
    kernels = get_basis(points, max_degree)

    found_kernels = len(kernels.keys())
    assert found_kernels == expected_kernels, 'correct number of basis kernels'
    def forward(
        self,
        feats,
        coors,
        mask = None,
        adj_mat = None,
        edges = None,
        return_type = None,
        return_pooled = False,
        neighbor_mask = None,
        global_feats = None
    ):
        """Select neighbors for every node and run the equivariant network.

        Args:
            feats: either a tensor (passed through ``self.token_emb`` when that
                exists, then wrapped as ``{'0': feats[..., None]}``) or a dict
                keyed by degree strings ('0', '1', ...).
            coors: node coordinates, rearranged here as ``'b n d'``; relative
                positions are reshaped with a trailing size of 3.
            mask: optional ``'b n'`` node mask. A pair is kept only when both
                of its nodes are unmasked. The original mask is also used for
                pooling when ``return_pooled`` is set.
            adj_mat: optional adjacency matrix of shape ``(n, n)`` or
                ``(b, n, n)``. Asserted present when
                ``self.attend_sparse_neighbors`` is set.
            edges: optional edge input; embedded with ``self.edge_emb`` when
                that exists. Asserted present when ``self.has_edges`` is set.
            return_type: optional degree whose features are returned; forced
                to 0 when ``self.output_degrees == 1``.
            return_pooled: when true, every degree is reduced over dim 1
                (masked mean if ``mask`` was given, plain mean otherwise).
            neighbor_mask: optional pairwise mask restricting which nodes may
                be picked as neighbors (assumed boolean ``(b, n, n)`` - TODO
                confirm against callers). Note the name is rebound further
                down to the mask of the selected neighbors.
            global_feats: optional global features; must be passed exactly
                when ``self.accept_global_feats`` is set. A tensor is wrapped
                as ``{'0': global_feats[..., None]}``.

        Returns:
            The dict of features keyed by degree string, or the single entry
            ``x[str(return_type)]`` when ``return_type`` is set.
        """
        # XOR: global_feats must be supplied exactly when the module accepts them
        assert not (self.accept_global_feats ^ exists(global_feats)), 'you cannot pass in global features unless you init the class correctly'

        # keep the per-node mask; `mask` is turned into a pairwise mask below
        _mask = mask

        if self.output_degrees == 1:
            return_type = 0

        if exists(self.token_emb):
            feats = self.token_emb(feats)

        assert not (self.attend_sparse_neighbors and not exists(adj_mat)), 'adjacency matrix (adjacency_mat) or edges (edges) must be passed in'
        assert not (self.has_edges and not exists(edges)), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

        # bare tensors are treated as degree-0 features with a trailing singleton axis
        if torch.is_tensor(feats):
            feats = {'0': feats[..., None]}

        if torch.is_tensor(global_feats):
            global_feats = {'0': global_feats[..., None]}

        b, n, d, *_, device = *feats['0'].shape, feats['0'].device

        assert d == self.dim_in[0], f'feature dimension {d} must be equal to dimension given at init {self.dim_in[0]}'
        assert set(map(int, feats.keys())) == set(range(self.input_degrees)), f'input must have {self.input_degrees} degree'

        num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius

        assert self.attend_sparse_neighbors or neighbors > 0, 'you must either attend to sparsely bonded neighbors, or set number of locally attended neighbors to be greater than 0'

        # se3 transformer by default cannot have a node attend to itself

        exclude_self_mask = rearrange(~torch.eye(n, dtype = torch.bool, device = device), 'i j -> () i j')
        # remove_self drops the diagonal (i == j) entries and reshapes to (b, n, n - 1)
        remove_self = lambda t: t.masked_select(exclude_self_mask).reshape(b, n, n - 1)
        # largest finite value representable in the tensor's dtype
        get_max_value = lambda t: torch.finfo(t.dtype).max

        # create N-degrees adjacent matrix from 1st degree connections

        if exists(self.num_adj_degrees):
            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

            adj_indices = adj_mat.clone().long()

            for ind in range(self.num_adj_degrees - 1):
                degree = ind + 2

                # entries where the squared adjacency differs from the current
                # one are labelled with the current degree
                next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
                next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
                adj_indices.masked_fill_(next_degree_mask, degree)
                adj_mat = next_degree_adj_mat.clone()

            # drop self-pairs so adj_indices matches the other (b, n, n - 1) edge tensors
            adj_indices = adj_indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)

        # calculate sparsely connected neighbors

        sparse_neighbor_mask = None
        num_sparse_neighbors = 0

        if self.attend_sparse_neighbors:
            assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            if exists(adj_mat):
                if len(adj_mat.shape) == 2:
                    adj_mat = repeat(adj_mat, 'i j -> b i j', b = b)

            adj_mat = remove_self(adj_mat)

            adj_mat_values = adj_mat.float()
            adj_mat_max_neighbors = adj_mat_values.sum(dim = -1).max().item()

            # more bonded neighbors than allowed: jitter the values before topk
            # (presumably to pick among equally bonded neighbors at random)
            if max_sparse_neighbors < adj_mat_max_neighbors:
                noise = torch.empty_like(adj_mat_values).uniform_(-0.01, 0.01)
                adj_mat_values += noise

            num_sparse_neighbors = int(min(max_sparse_neighbors, adj_mat_max_neighbors))
            values, indices = adj_mat_values.topk(num_sparse_neighbors, dim = -1)
            sparse_neighbor_mask = torch.zeros_like(adj_mat_values).scatter_(-1, indices, values)
            # the jitter is at most 0.01, so 0.5 still separates set from unset entries
            sparse_neighbor_mask = sparse_neighbor_mask > 0.5

        # exclude edge of token to itself

        indices = repeat(torch.arange(n, device = device), 'i -> b i j', b = b, j = n)
        rel_pos  = rearrange(coors, 'b n d -> b n () d') - rearrange(coors, 'b n d -> b () n d')

        indices = indices.masked_select(exclude_self_mask).reshape(b, n, n - 1)
        rel_pos = rel_pos.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, 3)

        # pairwise mask: a pair survives only when both of its nodes are unmasked
        if exists(mask):
            mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
            mask = mask.masked_select(exclude_self_mask).reshape(b, n, n - 1)

        if exists(edges):
            if exists(self.edge_emb):
                edges = self.edge_emb(edges)

            edges = edges.masked_select(exclude_self_mask[..., None]).reshape(b, n, n - 1, -1)

        if exists(self.adj_emb):
            adj_emb = self.adj_emb(adj_indices)
            edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb

        rel_dist = rel_pos.norm(dim = -1)

        # rel_dist gets modified using adjacency or neighbor mask
        # (on a clone, so rel_dist keeps the true distances gathered further down)

        modified_rel_dist = rel_dist.clone()
        max_value = get_max_value(modified_rel_dist) # for masking out nodes from being considered as neighbors

        # neighbors

        if exists(neighbor_mask):
            neighbor_mask = remove_self(neighbor_mask)

            max_neighbors = neighbor_mask.sum(dim = -1).max().item()
            if max_neighbors > neighbors:
                print(f'neighbor_mask shows maximum number of neighbors as {max_neighbors} but specified number of neighbors is {neighbors}')

            # pairs not allowed by the mask get the maximum distance, so a smallest-first topk ranks them last
            modified_rel_dist.masked_fill_(~neighbor_mask, max_value)

        # use sparse neighbor mask to assign priority of bonded
        # (distance 0 ranks them first in the smallest-first topk below)

        if exists(sparse_neighbor_mask):
            modified_rel_dist.masked_fill_(sparse_neighbor_mask, 0.)

        # mask out future nodes to high distance if causal turned on

        if self.causal:
            causal_mask = torch.ones(n, n - 1, device = device).triu().bool()
            modified_rel_dist.masked_fill_(causal_mask[None, ...], max_value)

        # if number of local neighbors by distance is set to 0, then only fetch the sparse neighbors defined by adjacency matrix

        if neighbors == 0:
            valid_radius = 0

        # get neighbors and neighbor mask, excluding self

        neighbors = int(min(neighbors, n - 1))
        total_neighbors = int(neighbors + num_sparse_neighbors)
        assert total_neighbors > 0, 'you must be fetching at least 1 neighbor'

        total_neighbors = int(min(total_neighbors, n - 1)) # make sure total neighbors does not exceed the length of the sequence itself

        dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors, dim = -1, largest = False)
        # from here on neighbor_mask is the validity mask of the selected neighbors,
        # no longer the argument passed in by the caller
        neighbor_mask = dist_values <= valid_radius

        neighbor_rel_dist = batched_index_select(rel_dist, nearest_indices, dim = 2)
        neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim = 2)
        neighbor_indices = batched_index_select(indices, nearest_indices, dim = 2)

        if exists(mask):
            neighbor_mask = neighbor_mask & batched_index_select(mask, nearest_indices, dim = 2)

        if exists(edges):
            edges = batched_index_select(edges, nearest_indices, dim = 2)

        # calculate rotary pos emb

        rotary_pos_emb = None
        rotary_query_pos_emb = None
        rotary_key_pos_emb = None

        if self.rotary_position:
            seq = torch.arange(n, device = device)
            seq_pos_emb = self.rotary_pos_emb(seq)
            # prepend each node's own index in front of its neighbor indices
            self_indices = torch.arange(neighbor_indices.shape[1], device = device)
            self_indices = repeat(self_indices, 'i -> b i ()', b = b)
            neighbor_indices_with_self = torch.cat((self_indices, neighbor_indices), dim = 2)
            pos_emb = batched_index_select(seq_pos_emb, neighbor_indices_with_self, dim = 0)

            rotary_key_pos_emb = pos_emb
            rotary_query_pos_emb = repeat(seq_pos_emb, 'n d -> b n d', b = b)

        if self.rotary_rel_dist:
            # prepend a zero entry on the last axis, then scale the distances by 1e2
            neighbor_rel_dist_with_self = F.pad(neighbor_rel_dist, (1, 0), value = 0) * 1e2
            rel_dist_pos_emb = self.rotary_pos_emb(neighbor_rel_dist_with_self)
            rotary_key_pos_emb = safe_cat(rotary_key_pos_emb, rel_dist_pos_emb, dim = -1)

            # queries are embedded at distance zero
            query_dist = torch.zeros(n, device = device)
            query_pos_emb = self.rotary_pos_emb(query_dist)
            query_pos_emb = repeat(query_pos_emb, 'n d -> b n d', b = b)

            rotary_query_pos_emb = safe_cat(rotary_query_pos_emb, query_pos_emb, dim = -1)

        if exists(rotary_query_pos_emb) and exists(rotary_key_pos_emb):
            rotary_pos_emb = (rotary_query_pos_emb, rotary_key_pos_emb)

        # calculate basis

        basis = get_basis(neighbor_rel_pos, num_degrees - 1, differentiable = self.differentiable_coors)

        # main logic

        edge_info = (neighbor_indices, neighbor_mask, edges)
        x = feats

        # project in

        x = self.conv_in(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)

        # preconvolution layers

        for conv, nonlin in self.convs:
            x = nonlin(x)
            x = conv(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)

        # transformer layers

        x = self.net(x, edge_info = edge_info, rel_dist = neighbor_rel_dist, basis = basis, global_feats = global_feats, pos_emb = rotary_pos_emb)

        # project out

        x = self.conv_out(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)

        # norm

        x = self.norm(x)

        # reduce dim if specified

        if exists(self.linear_out):
            x = self.linear_out(x)
            x = map_values(lambda t: t.squeeze(dim = 2), x)

        # pool over dim 1, using the original per-node mask when one was given
        if return_pooled:
            mask_fn = (lambda t: masked_mean(t, _mask, dim = 1)) if exists(_mask) else (lambda t: t.mean(dim = 1))
            x = map_values(mask_fn, x)

        if '0' in x:
            x['0'] = x['0'].squeeze(dim = -1)

        if exists(return_type):
            return x[str(return_type)]

        return x
Esempio n. 4
0
    def forward(self, feats, coors, mask=None, edges=None, return_type=None):
        """Gather each node's nearest neighbors and run the network on them.

        Args:
            feats: either a tensor (passed through ``self.token_emb`` when that
                exists, then wrapped as ``{'0': feats[..., None]}``) or a dict
                keyed by degree strings.
            coors: node coordinates, rearranged here as ``'b n d'``; relative
                positions are reshaped with a trailing size of 3.
            mask: optional ``'b n'`` node mask; a pair is kept only when both
                of its nodes are unmasked.
            edges: optional edge input, embedded with ``self.edge_emb``.
            return_type: optional degree whose features are returned.

        Returns:
            The dict of features keyed by degree string, or the single entry
            ``x[str(return_type)]`` when ``return_type`` is set.
        """
        if exists(self.token_emb):
            feats = self.token_emb(feats)

        assert not exists(edges) or exists(self.edge_emb), \
            'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

        if exists(edges):
            edges = self.edge_emb(edges)

        # a bare tensor becomes degree-0 features with a trailing singleton axis
        if torch.is_tensor(feats):
            feats = {'0': feats[..., None]}

        scalars = feats['0']
        device = scalars.device
        b, n, d = scalars.shape[:3]

        assert d == self.dim, f'feature dimension {d} must be equal to dimension given at init {self.dim}'

        present_degrees = {int(degree) for degree in feats.keys()}
        assert present_degrees == set(range(self.input_degrees)), f'input must have {self.input_degrees} degree'

        num_degrees = self.num_degrees
        neighbors = min(self.num_neighbors, n - 1)

        # every pairwise tensor below has its diagonal (node paired with itself) removed

        identity = torch.eye(n, dtype=torch.bool, device=device)
        off_diagonal = rearrange(~identity, 'i j -> () i j')

        def drop_self_pairs(t, *trailing):
            # trailing sizes are given for tensors carrying a feature axis after the pair axes
            selector = off_diagonal[..., None] if trailing else off_diagonal
            return t.masked_select(selector).reshape(b, n, n - 1, *trailing)

        pair_indices = repeat(torch.arange(n, device=device), 'i -> b i j', b=b, j=n)
        pair_offsets = rearrange(coors, 'b n d -> b n () d') - rearrange(coors, 'b n d -> b () n d')

        pair_indices = drop_self_pairs(pair_indices)
        pair_offsets = drop_self_pairs(pair_offsets, 3)

        if exists(mask):
            rows = rearrange(mask, 'b i -> b i ()')
            cols = rearrange(mask, 'b j -> b () j')
            mask = drop_self_pairs(rows * cols)

        if exists(edges):
            edges = drop_self_pairs(edges, -1)

        pair_dist = pair_offsets.norm(dim=-1)

        # nearest neighbors by distance, smallest first

        nearest = pair_dist.topk(neighbors, dim=-1, largest=False)
        neighbor_rel_dist, nearest_indices = nearest.values, nearest.indices

        neighbor_rel_pos = batched_index_select(pair_offsets, nearest_indices, dim=2)
        neighbor_indices = batched_index_select(pair_indices, nearest_indices, dim=2)

        basis = get_basis(neighbor_rel_pos,
                          num_degrees - 1,
                          differentiable=self.differentiable_coors)

        # a neighbor counts only within the valid radius and when unmasked

        neighbor_mask = neighbor_rel_dist <= self.valid_radius

        if exists(mask):
            neighbor_mask = neighbor_mask & batched_index_select(mask, nearest_indices, dim=2)

        if exists(edges):
            edges = batched_index_select(edges, nearest_indices, dim=2)

        # project in, transformer layers, project out, norm

        edge_info = (neighbor_indices, neighbor_mask, edges)
        geometry = dict(rel_dist=neighbor_rel_dist, basis=basis)

        x = self.conv_in(feats, edge_info, **geometry)
        x = self.net(x, edge_info=edge_info, **geometry)
        x = self.conv_out(x, edge_info, **geometry)
        x = self.norm(x)

        # reduce dim if specified

        if exists(self.linear_out):
            projected = self.linear_out(x)
            x = {}
            for degree, tensor in projected.items():
                x[degree] = tensor.squeeze(dim=2)

        return x[str(return_type)] if exists(return_type) else x