def test_empty_inputs():
    """Zero-sized batch axes must broadcast and propagate through the tensor product."""
    irreps = '0e + 1e'
    tp = FullyConnectedTensorProduct(irreps, irreps, irreps)

    left = torch.randn(2, 1, 0, 1, 4)
    right = torch.randn(1, 2, 0, 3, 4)
    assert tp(left, right).shape == (2, 2, 0, 3, 4)

    # right() leaves the first operand open, producing an extra (in, out) matrix axis
    assert tp.right(torch.randn(1, 2, 0, 3, 4)).shape == (1, 2, 0, 3, 4, 4)
class Convolution(torch.nn.Module):
    r"""Equivariant 3D convolution acting on voxel grids.

    The output is the sum of a pointwise self-connection (an equivariant
    ``Linear`` applied at every voxel) and a ``conv3d`` whose kernel is built
    from spherical harmonics sampled on a cubic lattice and weighted by a
    learned radial profile.

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        irreps of the input features
    irreps_out : `e3nn.o3.Irreps`
        irreps of the output features
    irreps_sh : `e3nn.o3.Irreps`
        irreps of the spherical harmonics used by the filter, typically
        ``o3.Irreps.spherical_harmonics(lmax)``
    diameter : float
        diameter of the filter in physical units
    num_radial_basis : int
        number of radial basis functions
    steps : tuple of float
        voxel size along each of the three axes, in physical units
    **kwargs
        extra keyword arguments forwarded to ``torch.nn.functional.conv3d``;
        ``padding`` defaults to half the kernel size along each axis
    """

    def __init__(self, irreps_in, irreps_out, irreps_sh, diameter, num_radial_basis, steps=(1.0, 1.0, 1.0), **kwargs):
        super().__init__()

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_sh = o3.Irreps(irreps_sh)
        self.num_radial_basis = num_radial_basis

        # pointwise (self-interaction) term
        self.sc = Linear(self.irreps_in, self.irreps_out)

        # neighbor term: a symmetric grid of offsets spanning the filter radius
        radius = diameter / 2
        axes = []
        for step in (steps[0], steps[1], steps[2]):
            half_width = math.floor(radius / step)
            axes.append(torch.arange(-half_width, half_width + 1.0) * step)
        lattice = torch.stack(torch.meshgrid(*axes), dim=-1)  # [x, y, z, R^3]
        self.register_buffer('lattice', lattice)

        # with the default stride of 1, this padding keeps the spatial size unchanged
        kwargs.setdefault('padding', tuple(size // 2 for size in lattice.shape[:3]))
        self.kwargs = kwargs

        radial_embedding = soft_one_hot_linspace(
            x=lattice.norm(dim=-1),
            start=0.0,
            end=radius,
            number=self.num_radial_basis,
            basis='smooth_finite',
            cutoff=True,
        )  # [x, y, z, num_radial_basis]
        self.register_buffer('emb', radial_embedding)

        harmonics = o3.spherical_harmonics(
            l=self.irreps_sh,
            x=lattice,
            normalize=True,
            normalization='component',
        )  # [x, y, z, irreps_sh.dim]
        self.register_buffer('sh', harmonics)

        self.tp = FullyConnectedTensorProduct(
            self.irreps_in, self.irreps_sh, self.irreps_out, shared_weights=False
        )
        # one row of tensor-product weights per radial basis function
        self.weight = torch.nn.Parameter(torch.randn(self.num_radial_basis, self.tp.weight_numel))

    def kernel(self):
        """Build the ``conv3d`` weight, shaped ``[irreps_out.dim, irreps_in.dim, x, y, z]``."""
        num_points = self.sh.shape[0] * self.sh.shape[1] * self.sh.shape[2]
        # radial profile -> per-voxel tensor-product weights, normalized by the lattice size
        per_voxel_weight = (self.emb @ self.weight) / num_points
        kernel = self.tp.right(self.sh, per_voxel_weight)  # [x, y, z, irreps_in.dim, irreps_out.dim]
        # conv3d expects (out_channels, in_channels, kx, ky, kz)
        return kernel.permute(4, 3, 0, 1, 2)

    def forward(self, x):
        r"""Apply the self-connection plus the lattice convolution.

        Parameters
        ----------
        x : `torch.Tensor`
            tensor of shape ``(batch, irreps_in.dim, x, y, z)``

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(batch, irreps_out.dim, x, y, z)``
        """
        # Linear acts on the last axis, so move channels there and back again
        self_connection = self.sc(x.transpose(1, 4)).transpose(1, 4)
        neighbors = torch.nn.functional.conv3d(x, self.kernel(), **self.kwargs)
        return self_connection + neighbors