def __init__(self, Rs_in, Rs_out, linear=True, allow_change_output=False, allow_zero_outputs=False):
    """
    Build the tensor square of ``Rs_in`` and a ``KernelLinear`` mapping it to ``Rs_out``.

    Only the representation orders l that appear in ``Rs_out`` are kept when forming the
    tensor square (via the ``lfilter`` of ``o3.selection_rule``).

    :param Rs_in: input representation list, simplified with ``rs.simplify``
    :param Rs_out: output representation list, simplified with ``rs.simplify``
    :param linear: if True, a scalar ``(1, 0, 1)`` is prepended to ``Rs_in`` before squaring
        (presumably so the square also contains terms linear in the input — TODO confirm)
    :param allow_change_output: if True, drop entries of ``Rs_out`` whose l does not occur
        in the tensor square instead of asserting
    :param allow_zero_outputs: only consulted when ``allow_change_output`` is False; if True,
        skip the assertion that every l of ``Rs_out`` occurs in the tensor square
    """
    super().__init__()

    self.Rs_in = rs.simplify(Rs_in)
    self.Rs_out = rs.simplify(Rs_out)

    # Use a dedicated name for the l's captured by the lambda: the list of tensor-square
    # l's below must not rebind the closure variable (late-binding closure hazard).
    ls_out = [l for _, l, _ in self.Rs_out]
    selection_rule = partial(o3.selection_rule, lfilter=lambda l: l in ls_out)

    if linear:
        Rs_in_sq = [(1, 0, 1)] + self.Rs_in
    else:
        Rs_in_sq = self.Rs_in
    self.linear = linear

    Rs_ts, T = rs.tensor_square(Rs_in_sq, selection_rule)
    register_sparse_buffer(self, 'T', T)  # [out, in1 * in2]

    ls_ts = [l for _, l, _ in Rs_ts]
    if allow_change_output:
        self.Rs_out = [(mul, l, p) for mul, l, p in self.Rs_out if l in ls_ts]
    elif not allow_zero_outputs:
        assert all(l in ls_ts for _, l, _ in self.Rs_out), \
            "some l of Rs_out cannot be produced by the tensor square of Rs_in"

    self.kernel = KernelLinear(Rs_ts, self.Rs_out)  # [out, in, w]
def __init__(self, Rs_in, Rs_out, RadialModel, r, r_eps=0, selection_rule=o3.selection_rule_in_out_sh, normalization='component'):
    """
    Kernel for a fixed set of relative positions ``r`` whose spherical harmonics are precomputed.

    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param tensor r: [..., 3]
    :param float r_eps: distance considered as zero
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param normalization: either 'norm' or 'component'

    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    self.check_input_output(selection_rule)

    *self.size, xyz = r.size()
    assert xyz == 3
    r = r.reshape(-1, 3)  # [batch, space]
    self.register_buffer('radii', r.norm(2, dim=1))  # [batch]
    self.r_eps = r_eps

    self.tp = rs.TensorProduct(self.Rs_in, selection_rule, self.Rs_out, normalization, sorted=True)
    self.Rs_f = self.tp.Rs_in2

    # Spherical harmonics are only evaluated for positions strictly farther than r_eps.
    Y = rsh.spherical_harmonics_xyz(
        [(1, l, p) for _, l, p in self.Rs_f],
        r[self.radii > self.r_eps],
    )  # [batch, l_filter * m_filter]

    # Normalize the spherical harmonics (the two modes are mutually exclusive)
    if normalization == 'component':
        Y.mul_(math.sqrt(4 * math.pi))
    elif normalization == 'norm':
        # Build the per-m scaling on Y's device/dtype: a CPU vector multiplied in place
        # into a CUDA tensor would raise a device-mismatch error.
        diag = math.sqrt(4 * math.pi) * torch.cat([
            torch.ones(2 * l + 1, dtype=Y.dtype, device=Y.device) / math.sqrt(2 * l + 1)
            for _, l, _ in self.Rs_f
        ])
        Y.mul_(diag)
    self.register_buffer('Y', Y)

    self.R = RadialModel(rs.mul_dim(self.Rs_f))

    # A KernelLinear is only created when some position is within r_eps
    # (presumably used for those "zero-distance" positions in forward — TODO confirm).
    if (self.radii <= self.r_eps).any():
        self.linear = KernelLinear(self.Rs_in, self.Rs_out)
    else:
        self.linear = None
def __init__(self, Rs_in, Rs_out, RadialModel, selection_rule=o3.selection_rule_in_out_sh, normalization='component', allow_unused_inputs=False, allow_zero_outputs=False):
    """
    Kernel built from an ``rs.TensorProduct`` between the input and the filter representations.

    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param normalization: either 'norm' or 'component'
    :param allow_unused_inputs: if True, skip ``self.check_input(selection_rule)``
    :param allow_zero_outputs: if True, skip ``self.check_output(selection_rule)``

    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    if not allow_unused_inputs:
        self.check_input(selection_rule)
    if not allow_zero_outputs:
        self.check_output(selection_rule)

    self.normalization = normalization

    # Pass the convention-normalized Rs_out (as the validation above and KernelLinear
    # below use), not the raw user argument.
    self.tp = rs.TensorProduct(self.Rs_in, selection_rule, self.Rs_out, normalization, sorted=True)
    self.Rs_f = self.tp.Rs_in2  # filter representations produced by the tensor product
    self.Ls = [l for _, l, _ in self.Rs_f]
    self.R = RadialModel(rs.mul_dim(self.Rs_f))

    self.linear = KernelLinear(self.Rs_in, self.Rs_out)
def __init__(self, Rs_in, Rs_out, RadialModel, selection_rule=o3.selection_rule_in_out_sh, sh=rsh.spherical_harmonics_xyz, normalization='component'):
    """
    Kernel whose geometric tensor ``Q`` is precomputed by ``kernel_geometric``.

    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param sh: spherical harmonics function of signature ([l_filter], xyz[..., 3]) -> Y[m, ...]
    :param normalization: either 'norm' or 'component'

    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in, self.Rs_out = rs.convention(Rs_in), rs.convention(Rs_out)
    self.check_input_output(selection_rule)

    filter_Rs, geometric_tensor = kernel_geometric(self.Rs_in, self.Rs_out, selection_rule, normalization)
    self.register_buffer('Q', geometric_tensor)  # [out, in, Y, R]

    # The spherical harmonics function is kept for later evaluation.
    self.sh = sh
    self.Ls = [l for _mul, l, _parity in filter_Rs]

    n_radial = rs.mul_dim(filter_Rs)
    self.R = RadialModel(n_radial)
    self.linear = KernelLinear(self.Rs_in, self.Rs_out)
def __init__(self, Rs_in, Rs_out, RadialModel, r, r_eps=0, selection_rule=o3.selection_rule_in_out_sh, sh=rsh.spherical_harmonics_xyz, normalization='component'):
    """
    Kernel for a fixed set of relative positions ``r``: the geometric tensor from
    ``kernel_geometric`` is contracted with the spherical harmonics of ``r`` up front.

    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param tensor r: [..., 3]
    :param float r_eps: distance considered as zero
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param sh: spherical harmonics function of signature ([l_filter], xyz[..., 3]) -> Y[m, ...]
    :param normalization: either 'norm' or 'component'

    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    self.check_input_output(selection_rule)

    *self.size, xyz = r.size()
    assert xyz == 3
    r = r.reshape(-1, 3)  # [batch, space]
    # Registered as a buffer so it moves with the module on .to()/.cuda(), like Q.
    # NOTE(review): this adds 'radii' to state_dict; strict loading of checkpoints saved
    # before this change will report it as a missing key.
    self.register_buffer('radii', r.norm(2, dim=1))  # [batch]
    self.r_eps = r_eps

    Rs_f, Q = kernel_geometric(self.Rs_in, self.Rs_out, selection_rule, normalization)  # Q: [out, in, Y, R]
    # Spherical harmonics are only evaluated for positions strictly farther than r_eps.
    Y = sh([l for _, l, _ in Rs_f], r[self.radii > self.r_eps])  # [batch, l_filter * m_filter]
    Q = torch.einsum('ijyw,zy->zijw', Q, Y)
    self.register_buffer('Q', Q)  # [batch, out, in, R]
    self.R = RadialModel(rs.mul_dim(Rs_f))

    # A KernelLinear is only created when some position is within r_eps
    # (presumably used for those "zero-distance" positions in forward — TODO confirm).
    if (self.radii <= self.r_eps).any():
        self.linear = KernelLinear(self.Rs_in, self.Rs_out)
    else:
        self.linear = None
def __init__(self, Rs_in1, Rs_in2, Rs_out, allow_change_output=False, allow_zero_outputs=False):
    """
    Build the tensor product of ``Rs_in1`` and ``Rs_in2`` and a ``KernelLinear`` mapping it to ``Rs_out``.

    Only the representation orders l that appear in ``Rs_out`` are kept when forming the
    tensor product (via the ``lfilter`` of ``o3.selection_rule``).

    :param Rs_in1: first input representation list, simplified with ``rs.simplify``
    :param Rs_in2: second input representation list, simplified with ``rs.simplify``
    :param Rs_out: output representation list, simplified with ``rs.simplify``
    :param allow_change_output: if True, drop entries of ``Rs_out`` whose l does not occur
        in the tensor product instead of asserting
    :param allow_zero_outputs: only consulted when ``allow_change_output`` is False; if True,
        skip the assertion that every l of ``Rs_out`` occurs in the tensor product
        (defaults to False, which keeps the previous behavior)
    """
    super().__init__()

    self.Rs_in1 = rs.simplify(Rs_in1)
    self.Rs_in2 = rs.simplify(Rs_in2)
    self.Rs_out = rs.simplify(Rs_out)

    # Use a dedicated name for the l's captured by the lambda: the list of tensor-product
    # l's below must not rebind the closure variable (late-binding closure hazard).
    ls_out = [l for _, l, _ in self.Rs_out]
    selection_rule = partial(o3.selection_rule, lfilter=lambda l: l in ls_out)

    Rs_ts, T = rs.tensor_product(self.Rs_in1, self.Rs_in2, selection_rule)
    register_sparse_buffer(self, 'T', T)  # [out, in1 * in2]

    ls_ts = [l for _, l, _ in Rs_ts]
    if allow_change_output:
        self.Rs_out = [(mul, l, p) for mul, l, p in self.Rs_out if l in ls_ts]
    elif not allow_zero_outputs:
        assert all(l in ls_ts for _, l, _ in self.Rs_out), \
            "some l of Rs_out cannot be produced by the tensor product of Rs_in1 and Rs_in2"

    self.kernel = KernelLinear(Rs_ts, self.Rs_out)  # [out, in, w]