def __init__(self, Rs_in, Rs_out, Rs_sh, RadialModel, groups=math.inf, normalization='component'):
    """Create the sub-modules of a point convolution.

    The parent class is initialised with ``aggr='add'`` and
    ``flow='target_to_source'``. Presumably it is torch_geometric's
    ``MessagePassing``; confirm against the class header, which is not visible here.

    Args:
        Rs_in: representation of the input features.
        Rs_out: representation of the output features.
        Rs_sh: representation of the spherical-harmonic (edge) features fed as
            the second operand of the tensor product.
        RadialModel: callable that is given ``self.tp.nweight`` and returns a module.
        groups: forwarded to ``GroupedWeightedTensorProduct``.
        normalization: forwarded to ``GroupedWeightedTensorProduct`` and kept
            as ``self.normalization``.
    """
    super().__init__(aggr='add', flow='target_to_source')

    # Only the stored copies are simplified; the sub-modules below receive
    # the representations exactly as the caller passed them.
    self.Rs_in = rs.simplify(Rs_in)
    self.Rs_out = rs.simplify(Rs_out)

    # Sub-modules are created in a fixed order (lin1, tp, rm, lin2) so that
    # parameter initialisation and registration order stay the same.
    self.lin1 = Linear(
        Rs_in,
        Rs_out,
        allow_unused_inputs=True,
        allow_zero_outputs=True,
    )
    tensor_product = GroupedWeightedTensorProduct(
        Rs_in,
        Rs_sh,
        Rs_out,
        groups=groups,
        normalization=normalization,
        own_weight=False,
    )
    self.tp = tensor_product
    # tp owns no weights; the radial model is sized to tp.nweight and
    # presumably supplies them in forward() (not visible here).
    self.rm = RadialModel(tensor_product.nweight)
    self.lin2 = Linear(Rs_out, Rs_out)

    self.Rs_sh = Rs_sh
    self.normalization = normalization
def test1(self):
    """Check that a random Linear layer maps N(0, 1) inputs to roughly N(0, 1) outputs.

    The outputs are histogrammed into an empirical density. The test then
    requires a small Kullback-Leibler divergence KL(P || Q) between that
    density and the standard normal density. The global default dtype is
    switched to float64 for the check and restored afterwards, so the setting
    does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs_in = [(5, 0), (20, 1), (15, 0), (20, 2)]
        Rs_out = [(5, 0), (10, 1), (10, 2), (5, 0)]

        with torch.no_grad():
            lin = Linear(Rs_in, Rs_out)
            features = torch.randn(10000, rs.dim(Rs_in))
            features = lin(features)

        bins, left, right = 100, -4, 4
        # torch.histc splits [left, right] into `bins` bins of equal width.
        bin_width = (right - left) / bins
        # Evaluate the reference density at the bin centres.
        x = left + bin_width * (torch.arange(bins) + 0.5)
        p = torch.histc(features, bins, left, right) / features.numel() / bin_width  # empirical density
        q = x.pow(2).div(-2).exp().div(math.sqrt(2 * math.pi))  # standard normal density

        # Riemann-sum approximation of KL(P || Q) = integral of p * log(p / q) dx.
        # The 1e-100 offset keeps log finite in empty bins, where p = 0 makes
        # the term vanish anyway.
        Dkl = ((p + 1e-100) / q).log().mul(p).sum() * bin_width
        self.assertLess(Dkl, 0.01)
    finally:
        torch.set_default_dtype(previous_dtype)
def __init__(self, Rs_mid_1, Rs_mid_2, mul_mid, Rs_out, get_l_mul=o3.selection_rule):
    """Create a tensor product of two representations followed by a linear map.

    Args:
        Rs_mid_1: first operand representation of the tensor product.
        Rs_mid_2: second operand representation of the tensor product.
        mul_mid: multiplier applied to the tensor-product output
            representation to form the linear layer's input; also stored
            as ``self.mul_mid``.
        Rs_out: output representation of the linear layer.
        get_l_mul: selection rule forwarded to ``TensorProduct``.
    """
    super().__init__()
    self.mul_mid = mul_mid

    product = TensorProduct(Rs_mid_1, Rs_mid_2, get_l_mul)
    self.m = product
    # Linear input is sized as mul_mid * (product output representation).
    # Presumably forward() feeds mul_mid stacked products; confirm.
    self.si = Linear(mul_mid * product.Rs_out, Rs_out)
def make_layer(Rs_in, Rs_out):
    """Build the modules of one network layer as a ``torch.nn.ModuleList``.

    When the enclosing-scope flag ``feature_product`` is truthy, the list
    starts with a ``TensorSquare`` of ``Rs_in`` (selection rule capped at
    ``lmax``) and a ``Linear`` mapping the squared representation back to
    ``Rs_in``. It always ends with a convolution from ``Rs_in`` to the gate's
    input representation, followed by the ``GatedBlock`` for ``Rs_out``.

    ``feature_product``, ``lmax``, ``K``, ``convolution``, ``swish`` and
    ``sigmoid`` are free names resolved from the enclosing scope.
    """
    modules = []
    if feature_product:
        square = TensorSquare(Rs_in, selection_rule=partial(o3.selection_rule, lmax=lmax))
        mix_back = Linear(square.Rs_out, Rs_in)
        modules.extend([square, mix_back])

    # The gate is built before the convolution because the convolution's
    # output representation is the gate's input representation.
    gate = GatedBlock(Rs_out, swish, sigmoid)
    conv = convolution(K, Rs_in, gate.Rs_in)
    modules.extend([conv, gate])
    return torch.nn.ModuleList(modules)
def __init__(self, Rs_in, Rs_out, lmax=3):
    """Create a point convolution with a built-in Gaussian radial model.

    The parent class is initialised with ``aggr='add'`` and
    ``flow='target_to_source'``. Presumably it is torch_geometric's
    ``MessagePassing``; confirm against the class header.

    Args:
        Rs_in: representation of the input features.
        Rs_out: representation of the output features.
        lmax: highest l in the spherical-harmonic representation ``self.Rs_sh``.
    """
    super().__init__(aggr='add', flow='target_to_source')

    radial_factory = partial(
        GaussianRadialModel,
        max_radius=1.2,
        min_radius=0.0,
        number_of_basis=3,
        h=100,
        L=2,
        act=swish,
    )
    # One entry per l = 0..lmax. The triples are presumably
    # (multiplicity, l, parity) with parity (-1)**l.
    sh_repr = [(1, l, (-1) ** l) for l in range(lmax + 1)]

    # Only the stored copies are simplified; sub-modules get the raw inputs.
    self.Rs_in = rs.simplify(Rs_in)
    self.Rs_out = rs.simplify(Rs_out)

    self.lin1 = Linear(
        Rs_in,
        Rs_out,
        allow_unused_inputs=True,
        allow_zero_outputs=True,
    )
    self.tp = GroupedWeightedTensorProduct(Rs_in, sh_repr, Rs_out, own_weight=False)
    # Sized to the tensor product's weight count (tp owns none).
    self.rm = radial_factory(self.tp.nweight)
    self.lin2 = Linear(Rs_out, Rs_out)
    self.Rs_sh = sh_repr
def test_equiv(self):
    """Check that Linear commutes with the representation matrices from rs.rep.

    Transforming the input with ``rs.rep(Rs_in, *angles)`` and then applying
    Linear must match applying Linear and then transforming with
    ``rs.rep(Rs_out, *angles)``, to within 1e-10. The global default dtype is
    switched to float64 for the check and restored afterwards, so the setting
    does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs_in = [(5, 0), (15, 1), (5, 0), (10, 2)]
        Rs_out = [(2, 0), (1, 1), (1, 2), (3, 0)]
        lin = Linear(Rs_in, Rs_out)

        f_in = torch.randn(100, rs.dim(Rs_in))
        # Three random angles; presumably rs.rep interprets them as a rotation.
        angles = torch.randn(3)

        # 'ij,zj->zi' applies the matrix to every row (sample) of the features.
        y1 = lin(torch.einsum('ij,zj->zi', rs.rep(Rs_in, *angles), f_in))
        y2 = torch.einsum('ij,zj->zi', rs.rep(Rs_out, *angles), lin(f_in))
        self.assertLess((y1 - y2).abs().max(), 1e-10)
    finally:
        torch.set_default_dtype(previous_dtype)
def __init__(self, Rs_in, mul, lmax, Rs_out, layers=3):
    """Create a stack of tensor-square / linear / S2-activation layers and a tail.

    Each of the ``layers`` stages is stored in ``self.layers`` as a
    ``ModuleList([TensorSquare, Linear, S2Activation])``. Each stage works as
    follows:
      * squares the current representation (selection rule capped at ``lmax``);
      * maps the direct sum ``Rs + square.Rs_out`` to ``mul`` copies of the
        l = 0..lmax representation;
      * applies an S2 activation, whose output times ``mul`` becomes the next
        representation.

    The tail, stored in ``self.tail`` as ``ModuleList([TensorSquare, Linear])``,
    squares once more, keeping only the l values present in ``Rs_out``. It then
    maps the direct sum to ``Rs_out``.
    """
    super().__init__()
    current = rs.simplify(Rs_in)
    Rs_out = rs.simplify(Rs_out)

    capped_rule = partial(o3.selection_rule, lmax=lmax)

    stages = []
    for _ in range(layers):
        # tensor product: nonlinear and mixes the l's
        square = TensorSquare(current, selection_rule=capped_rule)
        # direct sum
        current = current + square.Rs_out
        # linear: learned but don't mix l's
        Rs_act = [(1, l) for l in range(lmax + 1)]
        mix = Linear(current, mul * Rs_act, allow_unused_inputs=True)
        # s2 nonlinearity
        act = S2Activation(Rs_act, swish, res=20 * (lmax + 1))
        current = mul * act.Rs_out
        stages.append(torch.nn.ModuleList([square, mix, act]))
    self.layers = torch.nn.ModuleList(stages)

    # Only l values that appear in the (simplified) output representation
    # are kept by the final tensor square.
    wanted_l = {l for _, l, _ in Rs_out}

    def lfilter(l):
        return l in wanted_l

    tail_square = TensorSquare(current, selection_rule=partial(o3.selection_rule, lfilter=lfilter))
    current = current + tail_square.Rs_out
    tail_lin = Linear(current, Rs_out, allow_unused_inputs=True)
    self.tail = torch.nn.ModuleList([tail_square, tail_lin])
def __init__(self, Rs_in, Rs_out):
    """Create a grouped point convolution (4 groups) with a Gaussian radial model.

    The parent class is initialised with ``aggr='add'`` and
    ``flow='target_to_source'``. Presumably it is torch_geometric's
    ``MessagePassing``; confirm against the class header.

    Args:
        Rs_in: representation of the input features.
        Rs_out: representation of the output features.
    """
    super().__init__(aggr='add', flow='target_to_source')

    radial_factory = partial(
        GaussianRadialModel,
        max_radius=1.2,
        min_radius=0.0,
        number_of_basis=3,
        h=100,
        L=2,
        act=swish,
    )
    lmax = 3
    # One entry per l = 0..3. The triples are presumably
    # (multiplicity, l, parity) with parity (-1)**l.
    sh_repr = [(1, l, (-1) ** l) for l in range(lmax + 1)]

    self.tp = GroupedWeightedTensorProduct(Rs_in, sh_repr, Rs_out, groups=4, own_weight=False)
    # Sized to the tensor product's weight count (tp owns none).
    self.rm = radial_factory(self.tp.nweight)
    self.lin = Linear(Rs_out, Rs_out)
    self.Rs_sh = sh_repr
def __init__(self, Rs_mid_1, Rs_mid_2, mul_mid, Rs_out, selection_rule=o3.selection_rule):
    """Create a tensor product of two representations followed by a linear map.

    Args:
        Rs_mid_1: first operand representation of the tensor product.
        Rs_mid_2: second operand representation of the tensor product.
        mul_mid: multiplier applied to the tensor-product output
            representation to form the linear layer's input; also stored
            as ``self.mul_mid``.
        Rs_out: output representation of the linear layer.
        selection_rule: forwarded to ``TensorProduct``.
    """
    super().__init__()
    self.mul_mid = mul_mid

    product = TensorProduct(Rs_mid_1, Rs_mid_2, selection_rule)
    self.tp = product
    # Linear input is sized as mul_mid * (product output representation).
    # Presumably forward() feeds mul_mid stacked products; confirm.
    self.lin = Linear(mul_mid * product.Rs_out, Rs_out)