def __init__(self, Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution, linear=Linear, scalar_activation=swish, gate_activation=sigmoid, final_nonlinearity=True):
    """Build linear -> gated block -> convolution -> gated block -> linear (-> gated block).

    :param Rs_in: representation of the input features
    :param Rs_out: representation of the output features
    :param Rs_mid1: per-group representation fed to the convolution; the first
        gated block outputs ``groups * Rs_mid1`` (list repetition)
    :param Rs_mid2: per-group representation produced by the gated block after
        the convolution; the final linear layer reads ``groups * Rs_mid2``
    :param groups: number of repetitions of ``Rs_mid1`` / ``Rs_mid2``
    :param convolution: factory called as ``convolution(Rs_mid1, Rs)``
    :param linear: factory called as ``linear(Rs_in, Rs_out)``
    :param scalar_activation: activation for even (p == 1) scalars; other scalars use ``tanh``
    :param gate_activation: activation applied to the gate scalars
    :param final_nonlinearity: if False, the output linear layer is not followed
        by a gated block and ``self.act_out`` is None
    """
    super().__init__()

    def gated_block(Rs):
        # Scalars (l == 0) get a pointwise activation; every non-scalar copy gets
        # one even gate scalar, so the gate multiplicity is rs.mul_dim(nonscalars).
        # NOTE(review): Rs is unpacked as (mul, l, p) triplets, so it is assumed
        # to already follow rs.convention.
        scalars = [(mul, l, p) for (mul, l, p) in Rs if l == 0]
        act_scalars = [(mul, scalar_activation if p == 1 else tanh) for mul, l, p in scalars]
        nonscalars = [(mul, l, p) for (mul, l, p) in Rs if l > 0]
        gates = [(rs.mul_dim(nonscalars), 0, +1)]
        act_gates = [(-1, gate_activation)]
        return GatedBlockParity(scalars, act_scalars, gates, act_gates, nonscalars)

    # Linear with GatedBlock.
    # Each gated block is built before the layer that feeds it, which keeps the
    # original construction order (and hence parameter initialization order).
    act_in = gated_block(groups * Rs_mid1)
    self.lin_in = linear(Rs_in, act_in.Rs_in)
    self.act_in = act_in

    # Kernel with GatedBlock (per group: not multiplied by `groups`)
    act_mid = gated_block(Rs_mid2)
    self.conv = convolution(Rs_mid1, act_mid.Rs_in)
    self.act_mid = act_mid

    # Linear with or without GatedBlock
    if final_nonlinearity:
        act_out = gated_block(Rs_out)
        self.lin_out = linear(groups * Rs_mid2, act_out.Rs_in)
        self.act_out = act_out
    else:
        self.lin_out = linear(groups * Rs_mid2, Rs_out)
        self.act_out = None

    self.groups = groups
def __init__(self, Rs_in, mul, Rs_out, lmax, layers=3, max_radius=1.0, number_of_basis=3, radial_layers=3, kernel=Kernel, convolution=Convolution, min_radius=0.0):
    """Stack `layers` (convolution, GatedBlockParity) pairs, then one convolution to `Rs_out`.

    :param Rs_in: input representation (normalized with ``rs.convention``)
    :param mul: multiplicity used for every hidden scalar and non-scalar irrep
    :param Rs_out: output representation of the final convolution
    :param lmax: largest order of hidden non-scalars and of the selection rule
    :param layers: number of hidden (convolution, gate) pairs
    :param max_radius, number_of_basis, min_radius: forwarded to ``GaussianRadialModel``
    :param radial_layers: forwarded to ``GaussianRadialModel`` as ``L``
    :param kernel: kernel factory, partially applied with the radial model and selection rule
    :param convolution: called with a constructed kernel
    """
    super().__init__()

    radial = partial(GaussianRadialModel, max_radius=max_radius, number_of_basis=number_of_basis, h=100, L=radial_layers, act=swish, min_radius=min_radius)
    selection = partial(o3.selection_rule_in_out_sh, lmax=lmax)
    make_kernel = partial(kernel, RadialModel=radial, selection_rule=selection)

    def gate_for(Rs_current):
        # Irreps are filtered by rs.haslinearpath against the current representation.
        scalars = [(mul, 0, p) for p in (+1, -1) if rs.haslinearpath(Rs_current, 0, p)]
        act_scalars = [(m, swish if p == 1 else tanh) for m, _, p in scalars]
        nonscalars = [(mul, l, p) for l in range(1, lmax + 1) for p in (+1, -1) if rs.haslinearpath(Rs_current, l, p)]
        # Even gates with sigmoid when an even scalar is reachable, otherwise odd gates with tanh.
        if rs.haslinearpath(Rs_current, 0, +1):
            gates, act_gates = [(rs.mul_dim(nonscalars), 0, +1)], [(-1, sigmoid)]
        else:
            gates, act_gates = [(rs.mul_dim(nonscalars), 0, -1)], [(-1, tanh)]
        return GatedBlockParity(scalars, act_scalars, gates, act_gates, nonscalars)

    Rs = rs.convention(Rs_in)
    blocks = []
    for _ in range(layers):
        act = gate_for(Rs)
        conv = convolution(make_kernel(Rs, act.Rs_in))
        Rs = act.Rs_out
        blocks.append(torch.nn.ModuleList([conv, act]))
    self.layers = torch.nn.ModuleList(blocks)

    # The last kernel may ignore some of its inputs.
    final_kernel = partial(make_kernel, allow_unused_inputs=True)
    self.layers.append(convolution(final_kernel(Rs, Rs_out)))
def make_gated_block(Rs_in, mul=16, lmax=3):
    """Return a GatedBlockParity whose irreps are filtered by ``rs.haslinearpath`` against ``Rs_in``.

    Scalars (l = 0) of both parities and non-scalars (1 <= l <= lmax, both parities)
    are kept only when ``rs.haslinearpath(Rs_in, l, p)`` holds, each with multiplicity
    ``mul``. Even scalars use ``swish`` and odd scalars use ``tanh``. Gates are even
    scalars with ``sigmoid`` when an even scalar is reachable; otherwise they are odd
    scalars with ``tanh``.
    """
    parities = (+1, -1)
    scalars = [(mul, 0, p) for p in parities if rs.haslinearpath(Rs_in, 0, p)]
    act_scalars = [(m, swish if p == 1 else tanh) for m, _, p in scalars]
    nonscalars = [(mul, l, p) for l in range(1, lmax + 1) for p in parities if rs.haslinearpath(Rs_in, l, p)]

    # One gate scalar per non-scalar copy.
    n_gates = rs.mul_dim(nonscalars)
    if rs.haslinearpath(Rs_in, 0, +1):
        gates, act_gates = [(n_gates, 0, +1)], [(-1, sigmoid)]
    else:
        gates, act_gates = [(n_gates, 0, -1)], [(-1, tanh)]
    return GatedBlockParity(scalars, act_scalars, gates, act_gates, nonscalars)
def __init__(self, Rs_in, Rs_out, RadialModel, r, r_eps=0, selection_rule=o3.selection_rule_in_out_sh, normalization='component'):
    """
    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param tensor r: [..., 3]
    :param float r_eps: distance considered as zero
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param normalization: either 'norm' or 'component'
    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    self.check_input_output(selection_rule)

    *self.size, xyz = r.size()
    assert xyz == 3
    r = r.reshape(-1, 3)  # [batch, space]
    self.register_buffer('radii', r.norm(2, dim=1))  # [batch]
    self.r_eps = r_eps

    self.tp = rs.TensorProduct(self.Rs_in, selection_rule, self.Rs_out, normalization, sorted=True)
    self.Rs_f = self.tp.Rs_in2

    # Spherical harmonics are evaluated only at points farther than r_eps from the origin.
    Y = rsh.spherical_harmonics_xyz(
        [(1, l, p) for _, l, p in self.Rs_f],
        r[self.radii > self.r_eps]
    )  # [batch, l_filter * m_filter]

    # Normalize the spherical harmonics
    if normalization == 'component':
        Y.mul_(math.sqrt(4 * math.pi))
    elif normalization == 'norm':
        # Build the per-column scale on Y's device/dtype: an in-place multiply by a
        # CPU tensor fails when r (and hence Y) lives on CUDA.
        diag = math.sqrt(4 * math.pi) * torch.cat([
            torch.ones(2 * l + 1, dtype=Y.dtype, device=Y.device) / math.sqrt(2 * l + 1)
            for _, l, _ in self.Rs_f
        ])
        Y.mul_(diag)
    self.register_buffer('Y', Y)

    self.R = RadialModel(rs.mul_dim(self.Rs_f))

    # Points within r_eps of the origin have no harmonics above. KernelLinear is only
    # created when such points exist (presumably used for them in forward; not visible here).
    if (self.radii <= self.r_eps).any():
        self.linear = KernelLinear(self.Rs_in, self.Rs_out)
    else:
        self.linear = None
def plot_data_on_grid(box_length, radial, Rs, sh=o3.spherical_harmonics_xyz, n=30):
    """Evaluate radial * spherical-harmonic functions on a cubic grid centered at the origin.

    :param box_length: side length of the cube spanned by the grid
    :param radial: callable mapping radii [n**3] to a tensor [n**3, rs.mul_dim(Rs)]
    :param Rs: list of (multiplicity, L) pairs
    :param sh: spherical harmonics function, called as ``sh(Ls, r)``
    :param n: number of grid points per axis
    :return: ``(r, all_f)`` with ``r`` the [n**3, 3] grid points and
        ``all_f`` the [n**3, rs.dim(Rs)] function values
    """
    # Assumes sh stacks the harmonics of each L in `Ls` along dim 0, in order;
    # L_to_index records the [start, stop) rows belonging to each L.
    Ls = sorted({L for _, L in Rs})
    L_to_index = {}
    start = 0
    for L in Ls:
        L_to_index[L] = [start, start + 2 * L + 1]
        start += 2 * L + 1

    r = np.mgrid[-1:1:n * 1j, -1:1:n * 1j, -1:1:n * 1j].reshape(3, -1)
    r = r.transpose(1, 0)
    r *= box_length / 2.
    r = torch.from_numpy(r)

    Ys = sh(Ls, r)
    R = radial(r.norm(2, -1)).detach()  # [r_values, n_r_filters]
    assert R.shape[-1] == rs.mul_dim(Rs)

    # Repeat each L's block of harmonic rows once per copy, matching the channel layout of Rs.
    Ys_indices = []
    for mul, L in Rs:
        Ys_indices += list(range(L_to_index[L][0], L_to_index[L][1])) * mul

    R_helper = rs.map_mul_to_Rs(Rs)  # [rs.dim(Rs), rs.mul_dim(Rs)]
    full_Ys = Ys[Ys_indices]
    full_Ys = full_Ys.reshape(full_Ys.shape[0], -1)  # [rs.dim(Rs), r_values]

    all_f = torch.einsum('xn,dn,dx->xd', R, R_helper, full_Ys)
    return r, all_f
def __init__(self, Rs_in, Rs_out, RadialModel, selection_rule=o3.selection_rule_in_out_sh, normalization='component', allow_unused_inputs=False, allow_zero_outputs=False):
    """
    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param normalization: either 'norm' or 'component'
    :param allow_unused_inputs: if True, skip ``self.check_input(selection_rule)``
    :param allow_zero_outputs: if True, skip ``self.check_output(selection_rule)``
    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    if not allow_unused_inputs:
        self.check_input(selection_rule)
    if not allow_zero_outputs:
        self.check_output(selection_rule)

    self.normalization = normalization

    # Pass the normalized self.Rs_out, consistent with self.Rs_in.
    self.tp = rs.TensorProduct(self.Rs_in, selection_rule, self.Rs_out, normalization, sorted=True)
    self.Rs_f = self.tp.Rs_in2
    self.Ls = [l for _, l, _ in self.Rs_f]
    # One radial output per filter copy.
    self.R = RadialModel(rs.mul_dim(self.Rs_f))

    self.linear = KernelLinear(self.Rs_in, self.Rs_out)
def __init__(self, Rs, activation, normalization='component'):
    """Store the representation, the activation, a ``Norm`` module and a zero bias.

    :param Rs: representation, normalized with ``rs.convention``
    :param activation: stored as ``self.activation``
    :param normalization: forwarded to ``Norm``
    """
    super().__init__()
    Rs = rs.convention(Rs)
    self.Rs = Rs
    self.activation = activation
    self.norm = Norm(Rs, normalization)
    # One trainable bias per multiplicity channel, initialized to zero.
    n_bias = rs.mul_dim(Rs)
    self.bias = torch.nn.Parameter(torch.zeros(n_bias))
def __init__(self, Rs_in, Rs_out, RadialModel, selection_rule=o3.selection_rule_in_out_sh, sh=rsh.spherical_harmonics_xyz, normalization='component'):
    """
    Equivariant kernel built from a geometric tensor (``kernel_geometric``) and a radial model.

    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param sh: spherical harmonics function of signature ([l_filter], xyz[..., 3]) -> Y[m, ...]
    :param normalization: either 'norm' or 'component'
    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    self.check_input_output(selection_rule)

    filters, geometric = kernel_geometric(self.Rs_in, self.Rs_out, selection_rule, normalization)
    self.register_buffer('Q', geometric)  # [out, in, Y, R]
    self.sh = sh
    # Orders of the filter harmonics and one radial output per filter copy.
    self.Ls = [l for _, l, _ in filters]
    self.R = RadialModel(rs.mul_dim(filters))
    self.linear = KernelLinear(self.Rs_in, self.Rs_out)
def __init__(self, Rs_in, mul, Rs_out, lmax, size=5, layers=3):
    """Stack `layers` (Convolution, GatedBlockParity) pairs, then a Convolution to `Rs_out`.

    :param Rs_in: input representation (normalized with ``rs.convention``)
    :param mul: multiplicity of every hidden scalar and non-scalar irrep
    :param Rs_out: output representation of the final convolution
    :param lmax: largest order of hidden non-scalars; also passed to every Convolution
    :param size: kernel size; every Convolution also gets ``padding=size // 2``
    :param layers: number of hidden (convolution, gate) pairs
    """
    super().__init__()

    modules = []

    Rs = rs.convention(Rs_in)
    for _ in range(layers):
        # Irreps are kept only when rs.haslinearpath(Rs, l, p) holds.
        scalars = [(mul, 0, p) for p in [+1, -1] if rs.haslinearpath(Rs, 0, p)]
        act_scalars = [(m, swish if p == 1 else tanh) for m, _, p in scalars]

        nonscalars = [(mul, l, p) for l in range(1, lmax + 1) for p in [+1, -1] if rs.haslinearpath(Rs, l, p)]

        # Even gates with sigmoid when an even scalar is reachable, otherwise odd gates with tanh.
        if rs.haslinearpath(Rs, 0, +1):
            gates = [(rs.mul_dim(nonscalars), 0, +1)]
            act_gates = [(-1, sigmoid)]
        else:
            gates = [(rs.mul_dim(nonscalars), 0, -1)]
            act_gates = [(-1, tanh)]

        act = GatedBlockParity(scalars, act_scalars, gates, act_gates, nonscalars)
        conv = Convolution(Rs, act.Rs_in, size, lmax=lmax, fuzzy_pixels=True, padding=size // 2)

        Rs = act.Rs_out

        block = torch.nn.Sequential(conv, act)
        modules.append(block)

    modules += [
        Convolution(Rs, Rs_out, size, lmax=lmax, fuzzy_pixels=True, padding=size // 2, allow_unused_inputs=True)
    ]

    self.layers = torch.nn.Sequential(*modules)
def test_mul_and_dot():
    lmax = 4
    n_coefficients = (lmax + 1) ** 2

    # Two signals with a single unit coefficient each, at indices 0 and 3.
    a = torch.zeros(n_coefficients)
    a[0] = 1.
    b = torch.zeros(n_coefficients)
    b[3] = 1.

    sph_a = SphericalTensor(a)
    sph_b = SphericalTensor(b)

    # The product is expected to be rs.mul_dim(sph_a.Rs) scalars with no parity.
    product = sph_a * sph_b
    expected_Rs = [(rs.mul_dim(sph_a.Rs), 0, 0)]
    assert rs.are_equal(product.Rs, expected_Rs)

    # dot is only checked for running without error.
    sph_a.dot(sph_b)
def __init__(self, Rs_in, Rs_out, RadialModel, r, r_eps=0, selection_rule=o3.selection_rule_in_out_sh, sh=rsh.spherical_harmonics_xyz, normalization='component'):
    """
    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param tensor r: [..., 3]
    :param float r_eps: distance considered as zero
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param sh: spherical harmonics function of signature ([l_filter], xyz[..., 3]) -> Y[m, ...]
    :param normalization: either 'norm' or 'component'
    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    self.check_input_output(selection_rule)

    *self.size, xyz = r.size()
    assert xyz == 3
    r = r.reshape(-1, 3)  # [batch, space]
    # NOTE(review): plain attribute, not a registered buffer, so it does not follow
    # .to(device)/.cuda() of the module. Left as-is to keep the state_dict unchanged.
    self.radii = r.norm(2, dim=1)  # [batch]
    self.r_eps = r_eps

    Rs_f, Q = kernel_geometric(self.Rs_in, self.Rs_out, selection_rule, normalization)  # Q: [out, in, Y, R]
    # Harmonics only for points farther than r_eps from the origin.
    Y = sh([l for _, l, _ in Rs_f], r[self.radii > self.r_eps])  # [batch, l_filter * m_filter]
    # The positions are fixed, so the harmonic axis of Q is contracted with Y once here.
    Q = torch.einsum('ijyw,zy->zijw', Q, Y)
    self.register_buffer('Q', Q)  # [batch, out, in, R]

    self.R = RadialModel(rs.mul_dim(Rs_f))

    # KernelLinear is only created when some points lie within r_eps of the origin.
    if (self.radii <= self.r_eps).any():
        self.linear = KernelLinear(self.Rs_in, self.Rs_out)
    else:
        self.linear = None
def test_norm(Rs, normalization):
    # Norm maps a batch of features to one value per multiplicity channel.
    norm_module = Norm(Rs, normalization=normalization)
    features = rs.randn(2, Rs, normalization=normalization)
    out = norm_module(features)
    assert out.shape == (2, rs.mul_dim(Rs))
def __init__(self, Rs_in, mul, Rs_out, lmax, layers=3, max_radius=1.0, number_of_basis=3, radial_layers=3, feature_product=False, kernel=Kernel, convolution=Convolution, min_radius=0.0):
    """Stack `layers` (convolution, gated nonlinearity) pairs, then a convolution to `Rs_out`.

    :param Rs_in: input representation (normalized with ``rs.convention``)
    :param mul: multiplicity of every hidden scalar and non-scalar irrep
    :param Rs_out: output representation of the final convolution
    :param lmax: largest order of hidden non-scalars and of the selection rule
    :param layers: number of hidden (convolution, nonlinearity) pairs
    :param max_radius, number_of_basis, min_radius: forwarded to ``GaussianRadialModel``
    :param radial_layers: forwarded to ``GaussianRadialModel`` as ``L``
    :param feature_product: if True, each gated block is followed by
        ``TransposeToMulL`` -> ``LearnableTensorSquare`` -> ``Flatten(2)``
    :param kernel: kernel factory, partially applied with the radial model and selection rule
    :param convolution: called with a constructed kernel
    :param min_radius: forwarded to ``GaussianRadialModel``
    """
    super().__init__()

    R = partial(GaussianRadialModel, max_radius=max_radius, number_of_basis=number_of_basis, h=100, L=radial_layers, act=swish, min_radius=min_radius)
    K = partial(kernel, RadialModel=R, selection_rule=partial(o3.selection_rule_in_out_sh, lmax=lmax))

    modules = []

    Rs = rs.convention(Rs_in)
    for _ in range(layers):
        # Irreps are kept only when rs.haslinearpath(Rs, l, p) holds.
        scalars = [(mul, 0, p) for p in [+1, -1] if rs.haslinearpath(Rs, 0, p)]
        act_scalars = [(m, swish if p == 1 else tanh) for m, _, p in scalars]

        nonscalars = [(mul, l, p) for l in range(1, lmax + 1) for p in [+1, -1] if rs.haslinearpath(Rs, l, p)]

        # Gate with even scalars + sigmoid only when an even scalar is reachable from Rs.
        # Otherwise use odd scalars + tanh, so the kernel is not asked for an unreachable output.
        if rs.haslinearpath(Rs, 0, +1):
            gates = [(rs.mul_dim(nonscalars), 0, +1)]
            act_gates = [(-1, sigmoid)]
        else:
            gates = [(rs.mul_dim(nonscalars), 0, -1)]
            act_gates = [(-1, tanh)]

        act = GatedBlockParity(scalars, act_scalars, gates, act_gates, nonscalars)
        conv = convolution(K(Rs, act.Rs_in))

        if feature_product:
            tr1 = rs.TransposeToMulL(act.Rs_out)
            lts = LearnableTensorSquare(tr1.Rs_out, [(1, l, p) for l in range(lmax + 1) for p in [-1, 1]], allow_change_output=True)
            tr2 = torch.nn.Flatten(2)
            act = torch.nn.Sequential(act, tr1, lts, tr2)
            Rs = tr1.mul * lts.Rs_out
        else:
            Rs = act.Rs_out

        block = torch.nn.ModuleList([conv, act])
        modules.append(block)

    self.layers = torch.nn.ModuleList(modules)

    # The last kernel may ignore some of its inputs.
    K = partial(K, allow_unused_inputs=True)
    self.layers.append(convolution(K(Rs, Rs_out)))
    self.feature_product = feature_product
def test_mul_dim():
    # rs.mul_dim adds up multiplicities regardless of the orders L.
    cases = [
        ([(1, 0), (3, 1), (2, 2)], 6),
        ([(1, 0), (3, 0), (2, 0)], 6),
    ]
    for Rs, expected in cases:
        assert rs.mul_dim(Rs) == expected
def test_mul_dimRs(self):
    # Both representations have total multiplicity 1 + 3 + 2 = 6.
    cases = (
        [(1, 0), (3, 1), (2, 2)],
        [(1, 0), (3, 0), (2, 0)],
    )
    for Rs in cases:
        self.assertTrue(rs.mul_dim(Rs) == 6)