def test_equivariance_only_sparse_neighbors():
    """Check rotation equivariance when neighbors come only from an adjacency matrix.

    The model is built with ``num_neighbors=0`` and ``attend_sparse_neighbors=True``.
    Presumably this makes attention use only the neighbors given by ``adj_mat`` rather
    than a nearest-neighbor search (confirm against SE3Transformer). The adjacency
    matrix connects each of the 32 nodes to itself and to its immediate
    predecessor/successor in sequence order (|i - j| <= 1).

    Rotating the input coordinates by ``R`` must give the same type-1 output as
    rotating the un-rotated model's type-1 output by ``R``.
    """
    model = SE3Transformer(
        dim=64,
        depth=1,
        attend_self=True,
        num_degrees=2,
        output_degrees=2,
        num_neighbors=0,
        attend_sparse_neighbors=True,
        num_adj_degrees=2,  # NOTE(review): presumably expands adjacency to 2 hops — confirm
        adj_dim=4,
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask = torch.ones(1, 32).bool()

    # Chain adjacency: node i is linked to node j iff |i - j| <= 1.
    seq = torch.arange(32)
    adj_mat = (seq[:, None] >= (seq[None, :] - 1)) & (seq[:, None] <= (seq[None, :] + 1))

    R = rot(15, 0, 45)
    out1 = model(feats, coors @ R, mask, adj_mat=adj_mat, return_type=1)
    out2 = model(feats, coors, mask, adj_mat=adj_mat, return_type=1) @ R

    # Use the absolute deviation: a signed max would let arbitrarily large
    # negative differences slip through the tolerance check.
    diff = (out1 - out2).abs().max()
    assert diff < 1e-4, 'is not equivariant'
def test_equivariance_with_type_one_input():
    """Check rotation equivariance when the model also receives type-1 input features.

    With ``input_degrees=2`` the model is fed a dict keyed by degree: ``'0'`` holds
    features with a trailing dimension of 1, and ``'1'`` holds features with a
    trailing dimension of 3. The type-1 features are rotated together with the
    coordinates, so the type-1 output must rotate by the same ``R``.
    """
    model = SE3Transformer(
        dim=64,
        depth=1,
        attend_self=True,
        num_neighbors=4,
        num_degrees=2,
        input_degrees=2,
        output_degrees=2,
    )

    atom_features = torch.randn(1, 32, 64, 1)
    pred_coors = torch.randn(1, 32, 64, 3)
    coors = torch.randn(1, 32, 3)
    mask = torch.ones(1, 32).bool()

    R = rot(15, 0, 45)
    # Rotated frame: both the type-1 inputs and the coordinates are rotated.
    out1 = model({'0': atom_features, '1': pred_coors @ R}, coors @ R, mask, return_type=1)
    # Original frame, output rotated afterwards.
    out2 = model({'0': atom_features, '1': pred_coors}, coors, mask, return_type=1) @ R

    # Use the absolute deviation: a signed max would let arbitrarily large
    # negative differences slip through the tolerance check.
    diff = (out1 - out2).abs().max()
    assert diff < 1e-4, 'is not equivariant'
def test_equivariance_with_reversible_network():
    """Check rotation equivariance of the model built with ``reversible=True``.

    Rotating the input coordinates by ``R`` must give the same type-1 output as
    rotating the un-rotated model's type-1 output by ``R``.
    """
    model = SE3Transformer(
        dim=64,
        depth=1,
        attend_self=True,
        num_neighbors=4,
        num_degrees=2,
        output_degrees=2,
        reversible=True,
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask = torch.ones(1, 32).bool()

    R = rot(15, 0, 45)
    out1 = model(feats, coors @ R, mask, return_type=1)
    out2 = model(feats, coors, mask, return_type=1) @ R

    # Use the absolute deviation: a signed max would let arbitrarily large
    # negative differences slip through the tolerance check.
    diff = (out1 - out2).abs().max()
    assert diff < 1e-4, 'is not equivariant'
def test_equivariance_linear_proj_keys():
    """Check rotation equivariance with ``fourier_encode_dist`` and ``linear_proj_keys`` enabled.

    Rotating the input coordinates by ``R`` must give the same type-1 output as
    rotating the un-rotated model's type-1 output by ``R``.
    """
    model = SE3Transformer(
        dim=64,
        depth=1,
        attend_self=True,
        num_neighbors=4,
        num_degrees=2,
        output_degrees=2,
        fourier_encode_dist=True,
        linear_proj_keys=True,
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask = torch.ones(1, 32).bool()

    R = rot(15, 0, 45)
    out1 = model(feats, coors @ R, mask, return_type=1)
    out2 = model(feats, coors, mask, return_type=1) @ R

    # Use the absolute deviation: a signed max would let arbitrarily large
    # negative differences slip through the tolerance check.
    diff = (out1 - out2).abs().max()
    assert diff < 1e-4, 'is not equivariant'