def __init__(self, num_classes):
    """Assemble three gated parity blocks, one scalar-only block, and a linear head.

    Args:
        num_classes: Output width of the final ``torch.nn.Linear`` layer.
    """
    super().__init__()

    radial = partial(CosineBasisModel, max_radius=3.0, number_of_basis=3, h=100, L=3, act=relu)
    make_kernel = partial(Kernel, RadialModel=radial)

    mul = 7
    parities = (+1, -1)
    blocks = []

    Rs = [(1, 0, +1)]
    for index in range(3):
        # Only keep the (l, p) channels for which haspath(Rs, l, p) holds.
        scalars = [(mul, 0, p) for p in parities if haspath(Rs, 0, p)]
        act_scalars = [(m, relu if p == 1 else tanh) for m, _, p in scalars]

        nonscalars = [(mul, 1, p) for p in parities if haspath(Rs, 1, p)]
        # One even scalar gate per non-scalar multiplicity.
        gate_count = sum(m for m, _, _ in nonscalars)
        gates = [(gate_count, 0, +1)]
        act_gates = [(-1, sigmoid)]

        print("layer {}: from {} to {}".format(index, rs.format_Rs(Rs), rs.format_Rs(scalars + nonscalars)))

        gated = GatedBlockParity(scalars, act_scalars, gates, act_gates, nonscalars)
        blocks.append(torch.nn.ModuleList([Convolution(make_kernel(Rs, gated.Rs_in)), gated]))
        Rs = gated.Rs_out

    # Final block: even and odd scalars only, no gates and no non-scalars.
    scalar_act = GatedBlockParity([(mul, 0, +1), (mul, 0, -1)], [(mul, relu), (mul, tanh)], [], [], [])
    blocks.append(torch.nn.ModuleList([Convolution(make_kernel(Rs, scalar_act.Rs_in)), scalar_act]))

    self.firstlayers = torch.nn.ModuleList(blocks)

    # The head is not equivariant, so it is allowed to mix even and odd scalars.
    self.lastlayers = torch.nn.Sequential(AvgSpacial(), torch.nn.Linear(2 * mul, num_classes))
def rotation_gated_block(self, K):
    """Check rotation equivariance of GatedBlock composed with a Convolution.

    Args:
        K: Kernel factory; it is specialised here with ``ConstantRadialModel``.
    """
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]

        kernel_factory = partial(K, RadialModel=ConstantRadialModel)
        gate = GatedBlock(Rs_out, scalar_activation=sigmoid, gate_activation=sigmoid)
        conv = Convolution(kernel_factory(Rs_in, gate.Rs_in))

        angles = torch.randn(3)
        rot_geo = o3.rot(*angles)
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)

        fea = torch.randn(1, 4, rs.dim(Rs_in))
        geo = torch.randn(1, 4, 3)

        apply_matrix = "ij,zaj->zai"

        # Transform the output of the untransformed inputs ...
        x1 = torch.einsum(apply_matrix, D_out, gate(conv(fea, geo)))

        # ... and compare with the output of the transformed inputs.
        fea_rotated = torch.einsum(apply_matrix, D_in, fea)
        geo_rotated = torch.einsum(apply_matrix, rot_geo, geo)
        x2 = gate(conv(fea_rotated, geo_rotated))

        self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
def parity_rotation_gated_block_parity(self, K):
    """Check parity and rotation equivariance of GatedBlockParity with a Convolution.

    Args:
        K: Kernel factory; it is specialised here with ``ConstantRadialModel``.
    """
    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]

        kernel_factory = partial(K, RadialModel=ConstantRadialModel)

        Rs_scalars = [(mul, 0, +1), (mul, 0, -1)]
        act_scalars = [(mul, relu), (mul, absolute)]
        rs_nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]
        n = 3 * mul
        Rs_gates = [(n, 0, +1), (n, 0, -1)]
        act_gates = [(n, sigmoid), (n, tanh)]

        gate = GatedBlockParity(Rs_scalars, act_scalars, Rs_gates, act_gates, rs_nonscalars)
        conv = Convolution(kernel_factory(Rs_in, gate.Rs_in))

        angles = torch.randn(3)
        # The sign flip turns the rotation into an improper one for the geometry.
        rot_geo = -o3.rot(*angles)
        D_in = rs.rep(Rs_in, *angles, 1)
        D_out = rs.rep(gate.Rs_out, *angles, 1)

        fea = torch.randn(1, 4, rs.dim(Rs_in))
        geo = torch.randn(1, 4, 3)

        apply_matrix = "ij,zaj->zai"

        x1 = torch.einsum(apply_matrix, D_out, gate(conv(fea, geo)))

        fea_transformed = torch.einsum(apply_matrix, D_in, fea)
        geo_transformed = torch.einsum(apply_matrix, rot_geo, geo)
        x2 = gate(conv(fea_transformed, geo_transformed))

        self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
def test1(self):
    """Check that GatedBlock applied after a Convolution commutes with rotations."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(3, 0), (3, 1), (2, 0), (1, 2)]
        Rs_out = [(3, 0), (3, 1), (1, 2), (3, 0)]

        f = GatedBlock(Rs_out, rescaled_act.Softplus(beta=5), rescaled_act.sigmoid)
        c = Convolution(Kernel(Rs_in, f.Rs_in, ConstantRadialModel))

        abc = torch.randn(3)

        def block_diagonal_rep(Rs):
            # One irreducible representation matrix per copy of each (mul, l) entry.
            irreps = [o3.irr_repr(l, *abc) for mul, l in Rs for _ in range(mul)]
            return o3.direct_sum(*irreps)

        D_in = block_diagonal_rep(Rs_in)
        D_out = block_diagonal_rep(Rs_out)

        n_features = sum(mul * (2 * l + 1) for mul, l in Rs_in)
        x = torch.randn(1, 5, n_features)
        geo = torch.randn(1, 5, 3)

        # Rotated inputs.
        rx = torch.einsum("ij,zaj->zai", (D_in, x))
        rgeo = geo @ o3.rot(*abc).t()

        # Rotated output of the unrotated inputs.
        y = f(c(x, geo), dim=2)
        ry = torch.einsum("ij,zaj->zai", (D_out, y))

        difference = f(c(rx, rgeo)) - ry
        self.assertLess(difference.norm(), 1e-10 * ry.norm())
def check_rotation_parity(batch: int = 10, n_atoms: int = 25):
    """Report whether a Convolution + GatedBlockParity pair is equivariant under an improper rotation.

    Args:
        batch: Leading dimension of the random features and geometry.
        n_atoms: Second dimension of the random features and geometry.

    Returns:
        The result of comparing the norm of the equivariance error against
        ``10e-5`` times the norm of the transformed output.
    """
    # Network under test.
    kernel_cls = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(1, 0, +1)]
    gate = GatedBlockParity(
        Rs_scalars=[(4, 0, +1)],
        act_scalars=[(-1, relu)],
        Rs_gates=[(8, 0, +1)],
        act_gates=[(-1, tanh)],
        Rs_nonscalars=[(4, 1, -1), (4, 2, +1)],
    )
    conv = Convolution(kernel_cls, Rs_in, gate.Rs_in)
    Rs_out = gate.Rs_out

    # Random transformation: Euler angles, plus a sign flip on the geometry
    # because positions have odd parity (improper rotation).
    angles = torch.randn(3)
    improper_rot = -o3.rot(*angles)
    D_in = rs.rep(Rs_in, *angles, parity=1)
    D_out = rs.rep(Rs_out, *angles, parity=1)

    # Random data: features transform with D_in, geometry with improper_rot.
    feat = torch.randn(batch, n_atoms, rs.dim(Rs_in))
    geo = torch.randn(batch, n_atoms, 3)

    # Transform-after versus transform-before.
    transformed_after = torch.einsum("ij,zkj->zki", D_out, gate(conv(feat, geo)))
    transformed_before = gate(conv(feat @ D_in.t(), geo @ improper_rot.t()))

    error = (transformed_after - transformed_before).norm()
    return error < 10e-5 * transformed_after.norm()
def check_rotation(batch: int = 10, n_atoms: int = 25):
    """Report whether a Convolution + GatedBlock pair is equivariant under a random rotation.

    Args:
        batch: Leading dimension of the random features and geometry.
        n_atoms: Second dimension of the random features and geometry.

    Returns:
        The result of comparing the norm of the equivariance error against
        ``10e-5`` times the norm of the rotated output.
    """
    # Network under test.
    kernel_cls = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(1, 0), (1, 1)]
    Rs_out = [(1, 0), (1, 1), (1, 2)]
    gate = GatedBlock(Rs_out, scalar_activation=sigmoid, gate_activation=absolute)
    conv = Convolution(kernel_cls, Rs_in, gate.Rs_in)

    # Random rotation given by Euler angles, with the matching representation matrices.
    angles = torch.randn(3)
    rotation = o3.rot(*angles)
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)

    # Random data: features transform with D_in, geometry with the rotation matrix.
    feat = torch.randn(batch, n_atoms, rs.dim(Rs_in))
    geo = torch.randn(batch, n_atoms, 3)

    # Rotate-after versus rotate-before.
    rotated_after = torch.einsum("ij,zkj->zki", D_out, gate(conv(feat, geo)))
    rotated_before = gate(conv(feat @ D_in.t(), geo @ rotation.t()))

    error = (rotated_after - rotated_before).norm()
    return error < 10e-5 * rotated_after.norm()
def test3(self):
    """Test rotation equivariance on GatedBlock and dependencies."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]

        kernel_cls = partial(Kernel, RadialModel=ConstantRadialModel)
        gate = GatedBlock(Rs_out, scalar_activation=sigmoid, gate_activation=sigmoid)
        conv = Convolution(kernel_cls, Rs_in, gate.Rs_in)

        abc = torch.randn(3)
        rot_geo = o3.rot(*abc)

        def block_diagonal_rep(Rs):
            # One irreducible representation matrix per copy of each (mul, l) entry.
            irreps = [o3.irr_repr(l, *abc) for mul, l in Rs for _ in range(mul)]
            return o3.direct_sum(*irreps)

        D_in = block_diagonal_rep(Rs_in)
        D_out = block_diagonal_rep(Rs_out)

        n_features = sum(mul * (2 * l + 1) for mul, l in Rs_in)
        fea = torch.randn(1, 4, n_features)
        geo = torch.randn(1, 4, 3)

        apply_matrix = "ij,zaj->zai"

        x1 = torch.einsum(apply_matrix, D_out, gate(conv(fea, geo)))

        fea_rotated = torch.einsum(apply_matrix, D_in, fea)
        geo_rotated = torch.einsum(apply_matrix, rot_geo, geo)
        x2 = gate(conv(fea_rotated, geo_rotated))

        self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
def test5(self):
    """Test parity equivariance on GatedBlockParity and dependencies."""
    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

        kernel_cls = partial(Kernel, RadialModel=ConstantRadialModel)

        Rs_scalars = [(mul, 0, +1), (mul, 0, -1)]
        act_scalars = [(mul, relu), (mul, absolute)]
        rs_nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]
        n = 3 * mul
        Rs_gates = [(n, 0, +1), (n, 0, -1)]
        act_gates = [(n, sigmoid), (n, tanh)]

        gate = GatedBlockParity(Rs_scalars, act_scalars, Rs_gates, act_gates, rs_nonscalars)
        conv = Convolution(kernel_cls, Rs_in, gate.Rs_in)

        def parity_rep(Rs):
            # Under inversion each (2l + 1)-dimensional block is just +/- identity.
            signed_blocks = [p * torch.eye(2 * l + 1) for m, l, p in Rs for _ in range(m)]
            return o3.direct_sum(*signed_blocks)

        D_in = parity_rep(Rs_in)
        D_out = parity_rep(gate.Rs_out)

        n_features = sum(m * (2 * l + 1) for m, l, p in Rs_in)
        fea = torch.randn(1, 4, n_features)
        geo = torch.randn(1, 4, 3)

        x1 = torch.einsum("ij,zaj->zai", (D_out, gate(conv(fea, geo))))

        # Inversion flips the sign of every position.
        fea_inverted = torch.einsum("ij,zaj->zai", (D_in, fea))
        x2 = gate(conv(fea_inverted, -geo))

        self.assertLess((x1 - x2).norm(), 10e-5 * x1.norm())
def __init__(self, num_radial=30, max_radius=2):
    """Create a single convolution from ``[(1, 0)]`` to ``[(1, 1)]``.

    Args:
        num_radial: Passed to ``CosineBasisModel`` as ``number_of_basis``.
        max_radius: Passed to ``CosineBasisModel`` as ``max_radius``.
    """
    super().__init__()

    softplus = rescaled_act.Softplus(beta=5)
    radial_factory = partial(
        CosineBasisModel,
        max_radius=max_radius,
        number_of_basis=num_radial,
        h=100,
        L=2,
        act=softplus,
    )
    kernel_factory = partial(Kernel, RadialModel=radial_factory)

    self.conv = Convolution(kernel_factory, [(1, 0)], [(1, 1)])
def get_kernel_conv_kernelconv(self, seed, normalization):
    """Build a ``Convolution`` and a ``KernelConv`` under the same random seed.

    Args:
        seed: Value handed to ``torch.manual_seed`` before each construction.
        normalization: Forwarded as the ``normalization`` keyword to both modules.

    Returns:
        The pair ``(Convolution, KernelConv)``, both built on ``self.Rs_in``
        and ``self.Rs_out`` with ``ConstantRadialModel``.
    """
    shared = dict(RadialModel=ConstantRadialModel, normalization=normalization)

    # Re-seed before each construction so both modules see the same random stream.
    torch.manual_seed(seed)
    conv = Convolution(partial(Kernel, **shared), self.Rs_in, self.Rs_out)

    torch.manual_seed(seed)
    kernel_conv = KernelConv(self.Rs_in, self.Rs_out, **shared)

    return conv, kernel_conv
def __init__(self, Rs_in, Rs_hidden, Rs_out, lmax, layers=3,
             max_radius=1.0, number_of_basis=3, radial_layers=3,
             feature_product=False, kernel=Kernel, convolution=Convolution):
    """Stack gated convolution layers followed by one plain output convolution.

    The network maps ``Rs_in`` to ``Rs_hidden`` through ``layers`` gated
    layers (the first from ``Rs_in``, the rest from ``Rs_hidden``), then
    applies a final convolution from the last representation before
    ``Rs_out`` to ``Rs_out``. With ``layers=0`` only the final convolution
    from ``Rs_in`` to ``Rs_out`` is created.

    Args:
        Rs_in: Representation of the input features.
        Rs_hidden: Representation produced by every gated layer.
        Rs_out: Representation of the network output.
        lmax: Forwarded as ``lmax`` to the selection rules.
        layers: Number of gated layers.
        max_radius: Forwarded to ``GaussianRadialModel``.
        number_of_basis: Forwarded to ``GaussianRadialModel``.
        radial_layers: Forwarded to ``GaussianRadialModel`` as ``L``.
        feature_product: When true, each gated layer is preceded by a
            ``TensorSquare`` and a ``Linear`` mapping back to the layer input
            representation.
        kernel: Kernel class, specialised with the radial model and selection rule.
        convolution: Convolution class used for every layer, including the
            final output layer.
    """
    super().__init__()

    representations = [Rs_in]
    representations += [Rs_hidden] * layers
    representations += [Rs_out]

    RadialModel = partial(
        GaussianRadialModel,
        max_radius=max_radius,
        number_of_basis=number_of_basis,
        h=100,
        L=radial_layers,
        act=swish,
    )
    K = partial(
        kernel,
        RadialModel=RadialModel,
        selection_rule=partial(o3.selection_rule_in_out_sh, lmax=lmax),
    )

    def make_layer(Rs_in, Rs_out):
        # Modules are created in the same order as before (product, linear,
        # gate, convolution) so seeded parameter initialisation is unchanged.
        prefix = []
        if feature_product:
            tp = TensorSquare(Rs_in, selection_rule=partial(o3.selection_rule, lmax=lmax))
            lin = Linear(tp.Rs_out, Rs_in)
            prefix = [tp, lin]
        act = GatedBlock(Rs_out, swish, sigmoid)
        conv = convolution(K, Rs_in, act.Rs_in)
        return torch.nn.ModuleList(prefix + [conv, act])

    self.layers = torch.nn.ModuleList([
        make_layer(Rs_layer_in, Rs_layer_out)
        for Rs_layer_in, Rs_layer_out in zip(representations[:-2], representations[1:-1])
    ])

    # Use the caller-supplied convolution class here as well; the output layer
    # previously hard-coded Convolution and ignored the `convolution` argument.
    self.layers.append(convolution(K, representations[-2], representations[-1]))
    self.feature_product = feature_product
def layer(Rs1, Rs2):
    """Return a Convolution whose kernel maps ``Rs1`` to ``Rs2``.

    The radial-model settings and ``lmax`` are read from the enclosing scope.
    """
    radial_factory = partial(
        GaussianRadialModel,
        max_radius=max_radius,
        number_of_basis=radial_basis,
        h=radial_neurons,
        L=radial_layers,
        act=radial_act,
        min_radius=min_radius,
    )
    selection = partial(o3.selection_rule_in_out_sh, lmax=lmax)
    kern = Kernel(Rs1, Rs2, radial_factory, selection, allow_unused_inputs=True)
    return Convolution(kern)
def make_layer(Rs_in, Rs_out):
    """Pair a convolution with the gated block it feeds.

    The gate is built first because the convolution's output representation
    is taken from the gate's ``Rs_in``. ``K`` comes from the enclosing scope.

    Returns:
        ``torch.nn.ModuleList([conv, act])``.
    """
    gate = GatedBlock(Rs_out, relu, sigmoid)
    return torch.nn.ModuleList([Convolution(K, Rs_in, gate.Rs_in), gate])
# Representations: the input carries one scalar and two vectors,
# the output is a single vector.
Rs_in = [(1, 0), (2, 1)]
Rs_out = [(1, 1)]

# Radial model, R+ -> R^d.
RadialModel = partial(
    GaussianRadialModel,
    max_radius=3.0,
    number_of_basis=3,
    h=100,
    L=1,
    act=swish,
)

# The kernel combines a radial part, which holds the learned parameters, with
# an angular part given by the spherical harmonics and the Clebsch-Gordan
# coefficients.
K = partial(Kernel, RadialModel=RadialModel)

# Convolution module built on that kernel.
conv = Convolution(K(Rs_in, Rs_out))

# Module computing the norm of each irreducible component of the output.
norm = Norm(Rs_out)

# Five input points with random features and positions; one output point at the origin.
n = 5
features = rs.randn(1, n, Rs_in, requires_grad=True)
in_geometry = torch.randn(1, n, 3)
out_geometry = torch.zeros(1, 1, 3)

conv_out = conv(features, in_geometry, out_geometry)
out = norm(conv_out)
out.backward()

print(out)
print(features.grad)