def test(Rs, act):
    """Check that S2Activation commutes with the rs.rep(..., 0, 0, 0, -1) transform.

    `Rs` is unpacked as (mul, l, p) triples to size a random (55, dim) input.
    Two orderings are compared along the last dim: activate then transform the
    output irreps, and transform the input irreps then activate. They must agree
    to better than 1e-10 (absolute). `self` is taken from the enclosing test
    method.
    """
    n_features = sum(2 * l + 1 for _, l, _ in Rs)
    inputs = torch.randn(55, n_features)
    activation = S2Activation(Rs, act, 1000)

    # Activation first, then the transform on the activation's output irreps.
    act_then_transform = activation(inputs, dim=-1) @ rs.rep(activation.Rs_out, 0, 0, 0, -1).T
    # Transform on the input irreps first, then the activation.
    transform_then_act = activation(inputs @ rs.rep(Rs, 0, 0, 0, -1).T, dim=-1)

    worst = (act_then_transform - transform_then_act).abs().max()
    self.assertLess(worst, 1e-10)
def __init__(self, Rs_in, mul, lmax, Rs_out, size=5, layers=3):
    """Build `layers` (ImageConvolution, S2Activation, LowPassFilter) blocks plus a tensor-square head.

    Args:
        Rs_in: input representation (passed through rs.simplify).
        mul: number of repetitions of the activation irreps list [0, ..., lmax].
        lmax: largest l used by the convolutions and the S2 activations.
        Rs_out: output representation of the head (passed through rs.simplify).
        size: size argument forwarded to ImageConvolution (presumably the
            kernel width); padding is set to size // 2.
        layers: number of (conv, act, pool) blocks.

    Each block is stored as a ModuleList in `self.layers`. The final
    LearnableTensorSquare is stored in `self.tail`.
    """
    super().__init__()
    Rs_hidden = rs.simplify(Rs_in)
    Rs_out = rs.simplify(Rs_out)
    # Activation irreps given as a plain list of l values 0..lmax.
    Rs_sphere = list(range(lmax + 1))
    self.mul = mul

    blocks = []
    for _ in range(layers):
        # `mul * Rs_sphere` repeats the l-list mul times (list repetition).
        conv = ImageConvolution(Rs_hidden, mul * Rs_sphere, size, lmax=lmax, fuzzy_pixels=True, padding=size // 2)
        # Nonlinearity evaluated on the sphere.
        sphere_act = S2Activation(Rs_sphere, swish, res=60)
        Rs_hidden = mul * sphere_act.Rs_out
        pool = LowPassFilter(scale=2.0, stride=2)
        blocks.append(torch.nn.ModuleList([conv, sphere_act, pool]))

    self.layers = torch.nn.ModuleList(blocks)
    self.tail = LearnableTensorSquare(Rs_hidden, Rs_out)
def __init__(self, Rs_in, mul, lmax, Rs_out, layers=3):
    """Build `layers` linear LearnableTensorSquare maps around one shared S2Activation.

    Args:
        Rs_in: input representation (simplified and stored as `self.Rs_in`).
        mul: multiplicity applied to the activation's input/output irreps.
        lmax: largest l of the activation irreps [0, ..., lmax]; also sets the
            activation resolution to 20 * (lmax + 1).
        Rs_out: output representation (simplified and stored as `self.Rs_out`).
        layers: number of linear maps stored in `self.layers`.

    A single S2Activation instance (`self.act`) is created. Its Rs_in/Rs_out
    fix the target and source irreps of every layer. The final
    LearnableTensorSquare is stored in `self.tail`.
    """
    super().__init__()
    self.Rs_in = rs.simplify(Rs_in)
    self.Rs_out = rs.simplify(Rs_out)
    self.act = S2Activation(list(range(lmax + 1)), swish, res=20 * (lmax + 1))

    Rs_hidden = self.Rs_in
    blocks = []
    for _ in range(layers):
        # Linear map onto mul copies of what the S2 activation expects as input.
        blocks.append(LearnableTensorSquare(Rs_hidden, mul * self.act.Rs_in, linear=True, allow_zero_outputs=True))
        # The next map consumes mul copies of the S2 activation's output irreps.
        Rs_hidden = mul * self.act.Rs_out

    self.layers = torch.nn.ModuleList(blocks)
    self.tail = LearnableTensorSquare(Rs_hidden, self.Rs_out)
def make_layer(Rs_in, Rs_out):
    """Return ModuleList([conv, act]) for a convolution Rs_in -> Rs_out and a sigmoid S2 activation.

    The activation acts on one copy of each l in 0..lmax with parity (-1)**l,
    and outputs up to lmax_out=lmax. `lmax`, `sigmoid`, `convolution` and `K`
    come from the enclosing scope.
    """
    sphere_irreps = [(1, l, (-1) ** l) for l in range(lmax + 1)]
    resolution = 20 * (lmax + 1)
    # The activation is constructed before the convolution, as in the original ordering.
    activation = S2Activation(sphere_irreps, sigmoid, lmax_out=lmax, res=resolution)
    kernel = K(Rs_in, Rs_out)
    conv = convolution(kernel)
    return torch.nn.ModuleList([conv, activation])
def test(Rs, act):
    """Check S2Activation equivariance under rs.rep with random angles and p = 1.

    The angles are three uniform samples from torch.rand. Activate-then-transform
    must match transform-then-activate to within 3e-4 of the largest output
    magnitude. `lmax` and `self` come from the enclosing scope.
    """
    inputs = rs.randn(2, Rs)
    activation = S2Activation(Rs, act, 200, lmax_out=lmax + 1, random_rot=True)
    alpha, beta, gamma = torch.rand(3)
    parity = 1

    # Activation first, then the transform on the output irreps.
    reference = activation(inputs) @ rs.rep(activation.Rs_out, alpha, beta, gamma, parity).T
    # Transform on the input irreps first, then the activation.
    candidate = activation(inputs @ rs.rep(Rs, alpha, beta, gamma, parity).T)

    tolerance = 3e-4 * reference.abs().max()
    self.assertLess((reference - candidate).abs().max(), tolerance)
def test(act, normalization):
    """Check S2Activation equivariance under rs.rep with o3.rand_angles and p = 1.

    Input and activation share `normalization`. The two orderings must agree to
    within 1e-10 of the largest output magnitude. `Rs` and `self` come from the
    enclosing scope.
    """
    inputs = rs.randn(2, Rs, normalization=normalization)
    activation = S2Activation(Rs, act, 120, normalization=normalization, lmax_out=6, random_rot=True)
    alpha, beta, gamma = o3.rand_angles()

    # Activation first, then the transform on the output irreps.
    reference = activation(inputs) @ rs.rep(activation.Rs_out, alpha, beta, gamma, 1).T
    # Transform on the input irreps first, then the activation.
    candidate = activation(inputs @ rs.rep(Rs, alpha, beta, gamma, 1).T)

    tolerance = 1e-10 * reference.abs().max()
    self.assertLess((reference - candidate).abs().max(), tolerance)
def __init__(self, Rs_in, mul, lmax, Rs_out, layers=3):
    """Build `layers` (TensorSquare, Linear, S2Activation) blocks and an l-filtered head.

    Args:
        Rs_in: input representation (passed through rs.simplify).
        mul: multiplicity of the activation irreps [(1, l) for l in 0..lmax].
        lmax: bound passed to o3.selection_rule in the hidden tensor squares and
            largest l of the activation irreps.
        Rs_out: output representation (passed through rs.simplify).
        layers: number of hidden blocks stored in `self.layers`.

    Per block, the hidden representation becomes (input ⊕ tensor square),
    then a Linear map onto mul copies of the activation irreps, then mul copies
    of the activation's output irreps. The head (`self.tail`) is a TensorSquare
    keeping only the l values present in Rs_out, followed by a Linear map onto
    Rs_out.
    """
    super().__init__()
    Rs_hidden = rs.simplify(Rs_in)
    Rs_out = rs.simplify(Rs_out)

    blocks = []
    for _ in range(layers):
        # Nonlinear step that mixes the l's.
        square = TensorSquare(Rs_hidden, selection_rule=partial(o3.selection_rule, lmax=lmax))
        # Direct sum: keep the inputs next to their tensor square.
        Rs_hidden = Rs_hidden + square.Rs_out
        # Learned linear map that does not mix the l's.
        Rs_sphere = [(1, l) for l in range(lmax + 1)]
        mix = Linear(Rs_hidden, mul * Rs_sphere, allow_unused_inputs=True)
        # Nonlinearity evaluated on the sphere.
        sphere_act = S2Activation(Rs_sphere, swish, res=20 * (lmax + 1))
        Rs_hidden = mul * sphere_act.Rs_out
        blocks.append(torch.nn.ModuleList([square, mix, sphere_act]))
    self.layers = torch.nn.ModuleList(blocks)

    # The head only needs tensor-square outputs whose l appears in Rs_out.
    wanted_l = [j for _, j, _ in Rs_out]

    def keep_l(l):
        return l in wanted_l

    head_square = TensorSquare(Rs_hidden, selection_rule=partial(o3.selection_rule, lfilter=keep_l))
    Rs_hidden = Rs_hidden + head_square.Rs_out
    head_linear = Linear(Rs_hidden, Rs_out, allow_unused_inputs=True)
    self.tail = torch.nn.ModuleList([head_square, head_linear])
def make_act(p_val, p_arg, act):
    """Return an S2Activation over one copy of each l in 0..lmax.

    The parity attached to irrep l is p_val * p_arg**l. The resolution is
    20 * (lmax + 1). `lmax` comes from the enclosing scope.
    """
    irreps = []
    for l in range(lmax + 1):
        irreps.append((1, l, p_val * p_arg ** l))
    resolution = 20 * (lmax + 1)
    return S2Activation(irreps, act, res=resolution)