def __init__(
    self,
    irreps_in,
    irreps_out,
    internal_weights=None,
    shared_weights=None,
):
    """Build a linear map from ``irreps_in`` to ``irreps_out``.

    The map is expressed as an ``o3.TensorProduct`` whose second operand is a
    single ``"0e"`` scalar. Each input irrep gets a trainable ``"uvw"`` path
    (path weight 1.0) to every output irrep of the same type.

    Parameters
    ----------
    irreps_in, irreps_out
        Anything accepted by ``o3.Irreps``.
    internal_weights, shared_weights
        Forwarded unchanged to ``o3.TensorProduct``.
    """
    super().__init__()
    irreps_in = o3.Irreps(irreps_in)
    irreps_out = o3.Irreps(irreps_out)

    # Connect every (input, output) pair that carries the same irrep.
    instr = []
    for i_in, (_, ir_in) in enumerate(irreps_in):
        for i_out, (_, ir_out) in enumerate(irreps_out):
            if ir_in == ir_out:
                instr.append((i_in, 0, i_out, "uvw", True, 1.0))

    self.tp = o3.TensorProduct(
        irreps_in,
        "0e",
        irreps_out,
        instr,
        internal_weights=internal_weights,
        shared_weights=shared_weights,
    )
    self.output_mask = self.tp.output_mask
    self.irreps_in = irreps_in
    self.irreps_out = irreps_out
def __init__(self, irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated):
    """Gate non-linearity.

    Scalars in ``irreps_scalars`` go through ``act_scalars``. Gate scalars in
    ``irreps_gates`` go through ``act_gates`` and are then multiplied
    element-wise with ``irreps_gated``. The output irreps are the activated
    scalars followed by the gated irreps.

    Raises
    ------
    ValueError
        If ``irreps_gates`` or ``irreps_scalars`` contain a non-scalar irrep,
        or if the number of gates differs from the number of gated irreps.
    """
    super().__init__()
    irreps_scalars = o3.Irreps(irreps_scalars)
    irreps_gates = o3.Irreps(irreps_gates)
    irreps_gated = o3.Irreps(irreps_gated)

    # Validation (gates first, then scalars, then counts).
    gates_not_scalar = len(irreps_gates) > 0 and irreps_gates.lmax > 0
    if gates_not_scalar:
        raise ValueError(
            f"Gate scalars must be scalars, instead got irreps_gates = {irreps_gates}"
        )
    scalars_not_scalar = len(irreps_scalars) > 0 and irreps_scalars.lmax > 0
    if scalars_not_scalar:
        raise ValueError(
            f"Scalars must be scalars, instead got irreps_scalars = {irreps_scalars}"
        )
    if irreps_gates.num_irreps != irreps_gated.num_irreps:
        raise ValueError(
            f"There are {irreps_gated.num_irreps} irreps in irreps_gated, but a different number ({irreps_gates.num_irreps}) of gate scalars in irreps_gates"
        )

    # Split the (sorted) input into its three parts.
    self.sc = _Sortcut(irreps_scalars, irreps_gates, irreps_gated)
    self.irreps_scalars, self.irreps_gates, self.irreps_gated = self.sc.irreps_outs
    self._irreps_in = self.sc.irreps_in

    # Activations are built from the caller's (unsimplified) irreps so they
    # line up one-to-one with the lists of activation functions.
    self.act_scalars = Activation(irreps_scalars, act_scalars)
    self.act_gates = Activation(irreps_gates, act_gates)
    self.mul = o3.ElementwiseTensorProduct(irreps_gated, self.act_gates.irreps_out)

    self._irreps_out = self.act_scalars.irreps_out + self.mul.irreps_out
def test_bias():
    irreps_in = o3.Irreps("2x0e + 1e + 2x0e + 0o")
    irreps_out = o3.Irreps("3x0e + 1e + 3x0e + 5x0e + 0o")
    m = o3.Linear(irreps_in, irreps_out, biases=[True, False, False, True, False])
    with torch.no_grad():
        m.bias[:].fill_(1.0)

    # With a zero input, only the biased blocks (first 3x0e and 5x0e) are non-zero:
    # 3x0e -> 1, 1e -> 0 (3 comps), 3x0e -> 0, 5x0e -> 1, 0o -> 0.
    expected = torch.tensor([1.0] * 3 + [0.0] * 6 + [1.0] * 5 + [0.0])
    out = m(torch.zeros(irreps_in.dim))
    assert torch.allclose(out, expected)
    assert_equivariant(m)
    assert_auto_jitable(m)

    m = o3.Linear("0e + 0o + 1e + 1o", "10x0e + 0o + 1e + 1o", biases=True)
    assert_equivariant(m)
    assert_auto_jitable(m)
    assert_normalized(m, n_weight=100, n_input=10_000, atol=0.5, weights=[m.weight])
def __init__(self, irreps_out, num_z, lmax) -> None:
    """Two-layer tensor-product model on spherical harmonics and edge attributes.

    Parameters
    ----------
    irreps_out
        Output irreps of the final tensor product.
    num_z
        Number of node types. ``num_z**2`` edge-type scalars are used
        (presumably one per ordered type pair; confirm against the caller).
    lmax
        Maximum degree of the spherical harmonics.
    """
    super().__init__()
    self.num_z = num_z
    self.irreps_sh = o3.Irreps.spherical_harmonics(lmax)

    num_edge_types = num_z**2
    # Multiply the edge-type one-hot by the spherical harmonics to get the
    # edge attributes: one 'uvu' path per degree l (irreps_sh has lmax + 1 entries).
    self.mul = TensorProduct(
        [(num_edge_types, "0e")],
        self.irreps_sh,
        [(num_edge_types, ir) for _, ir in self.irreps_sh],
        [(0, l, l, "uvu", False) for l in range(lmax + 1)],
    )
    irreps_attr = self.mul.irreps_out

    irreps_mid = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
    irreps_out = o3.Irreps(irreps_out)

    self.tp1 = FullyConnectedTensorProduct(
        irreps_in1=self.irreps_sh,
        irreps_in2=irreps_attr,
        irreps_out=irreps_mid,
    )
    self.tp2 = FullyConnectedTensorProduct(
        irreps_in1=irreps_mid,
        irreps_in2=irreps_attr,
        irreps_out=irreps_out,
    )
def network():
    """Return a small ``Network`` and a factory producing random input graphs."""
    num_nodes = 5
    irreps_in = o3.Irreps("3x0e + 2x1o")
    irreps_hidden = o3.Irreps("5x0e + 5x0o + 5x1e + 5x1o")
    irreps_attr = o3.Irreps("10x0e")
    irreps_out = o3.Irreps("2x0o + 2x1o + 2x2e")

    f = Network(
        irreps_in,
        irreps_hidden,
        irreps_out,
        irreps_attr,
        o3.Irreps.spherical_harmonics(3),
        layers=3,
        max_radius=2.0,
        number_of_basis=5,
        radial_layers=2,
        radial_neurons=100,
        num_neighbors=4.0,
        num_nodes=num_nodes,
    )

    def random_graph():
        # Between 3 and 7 nodes with random positions, features and attributes.
        n = random.randint(3, 7)
        return dict(
            pos=torch.randn(n, 3),
            x=f.irreps_in.randn(n, -1),
            z=f.irreps_node_attr.randn(n, -1),
        )

    return f, random_graph
def test_full():
    m = FullTensorProduct(
        o3.Irreps("1e + 2e + 3x3o"),
        o3.Irreps("1e + 2x2e + 2x3o"),
    )
    print(m)
    assert_equivariant(m)
    assert_auto_jitable(m)
def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr, irreps_node_output, fc_neurons, num_neighbors) -> None:
    """Point convolution with a learned self-interaction mixing coefficient.

    Parameters
    ----------
    irreps_node_input, irreps_node_attr, irreps_edge_attr, irreps_node_output
        Irreps of node features, node attributes, edge attributes and outputs.
    fc_neurons
        Hidden sizes of the radial network. ``tp.weight_numel`` is appended as
        the final layer size.
    num_neighbors
        Stored as-is (presumably a normalisation constant used in ``forward``).

    Raises
    ------
    AssertionError
        If no path of ``irreps_node_input x irreps_edge_attr`` reaches
        ``irreps_node_output``, or if ``alpha`` cannot produce a scalar.
    """
    super().__init__()
    self.irreps_node_input = o3.Irreps(irreps_node_input)
    self.irreps_node_attr = o3.Irreps(irreps_node_attr)
    self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
    self.irreps_node_output = o3.Irreps(irreps_node_output)
    self.num_neighbors = num_neighbors

    self.sc = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_output)
    self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_input)

    # Keep every product irrep that is wanted in the output, plus the odd
    # scalar "0o". Hoisted: the original rebuilt it on every inner iteration.
    scalar_0o = o3.Irrep(0, 1)
    irreps_mid = []
    instructions = []
    for i, (mul, ir_in) in enumerate(self.irreps_node_input):
        for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
            for ir_out in ir_in * ir_edge:
                if ir_out in self.irreps_node_output or ir_out == scalar_0o:
                    k = len(irreps_mid)
                    irreps_mid.append((mul, ir_out))
                    instructions.append((i, j, k, 'uvu', True))

    irreps_mid = o3.Irreps(irreps_mid)
    irreps_mid, p, _ = irreps_mid.sort()

    assert irreps_mid.dim > 0, f"irreps_node_input={self.irreps_node_input} times irreps_edge_attr={self.irreps_edge_attr} produces nothing in irreps_node_output={self.irreps_node_output}"

    # Re-index output slots to the sorted irreps_mid (p maps old -> new index).
    instructions = [
        (i_1, i_2, p[i_out], mode, train)
        for i_1, i_2, i_out, mode, train in instructions
    ]

    tp = TensorProduct(
        self.irreps_node_input,
        self.irreps_edge_attr,
        irreps_mid,
        instructions,
        internal_weights=False,
        shared_weights=False,
    )
    self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel], torch.nn.functional.silu)
    self.tp = tp

    self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_node_output)

    # inspired by https://arxiv.org/pdf/2002.10444.pdf
    # alpha starts at zero so the block initially contributes no gating term.
    self.alpha = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, "0e")
    with torch.no_grad():
        self.alpha.weight.zero_()
    assert self.alpha.output_mask[0] == 1.0, f"irreps_mid={irreps_mid} and irreps_node_attr={self.irreps_node_attr} are not able to generate scalars"
def test_linear():
    spec = "1e + 2e + 3x3o"
    irreps_in, irreps_out = o3.Irreps(spec), o3.Irreps(spec)

    m = o3.Linear(irreps_in, irreps_out)
    m(torch.randn(irreps_in.dim))

    assert_equivariant(m)
    assert_auto_jitable(m)
    assert_normalized(m, n_weight=100, n_input=10_000, atol=0.5)
def test_cat():
    irreps = o3.Irreps("4x1e + 6x2e + 12x2o") + o3.Irreps("1x1e + 2x2e + 12x4o")

    # (multiplicity, l) of each concatenated entry, in order.
    blocks = [(4, 1), (6, 2), (12, 2), (1, 1), (2, 2), (12, 4)]

    assert len(irreps) == len(blocks)
    assert irreps.ls == [l for mul, l in blocks for _ in range(mul)]
    assert irreps.lmax == max(l for _, l in blocks)
    assert irreps.num_irreps == sum(mul for mul, _ in blocks)
def test_getitem():
    irreps = o3.Irreps("16x1e + 3e + 2e + 5o")
    five_odd = o3.Irrep("5o")

    assert irreps[0] == (16, o3.Irrep("1e"))
    assert irreps[3] == (1, five_odd)
    assert irreps[-1] == (1, five_odd)

    # Slicing must return an Irreps, not a plain tuple.
    tail = irreps[2:]
    assert isinstance(tail, o3.Irreps)
    assert tail == o3.Irreps("2e + 5o")
def test_id():
    spec = "1e + 2e + 3x3o"
    irreps_in, irreps_out = o3.Irreps(spec), o3.Irreps(spec)

    m = Identity(irreps_in, irreps_out)
    print(m)
    m(torch.randn(irreps_in.dim))

    assert_equivariant(m)
    assert_auto_jitable(m, strict_shapes=False)
def tp_path_exists(irreps_in1, irreps_in2, ir_out):
    """Return whether ``ir_out`` appears in the product of some irrep of
    ``irreps_in1`` with some irrep of ``irreps_in2``.

    Arguments are converted with ``o3.Irreps`` / ``o3.Irrep``. The search
    stops at the first matching pair.
    """
    irreps_in1 = o3.Irreps(irreps_in1).simplify()
    irreps_in2 = o3.Irreps(irreps_in2).simplify()
    ir_out = o3.Irrep(ir_out)
    return any(
        ir_out in ir1 * ir2
        for _, ir1 in irreps_in1
        for _, ir2 in irreps_in2
    )
def test_assert_equivariant():
    def not_equivariant(x1, x2):
        # Component-wise product of differently-typed irreps: not equivariant.
        return x1 * x2

    not_equivariant.irreps_in1 = o3.Irreps("2x0e + 1x1e + 3x2o + 1x4e")
    not_equivariant.irreps_in2 = o3.Irreps("2x0o + 3x0o + 3x2e + 1x4o")
    not_equivariant.irreps_out = o3.Irreps("1x1e + 2x0o + 3x2e + 1x4o")

    # The three irreps must share a dimension for the product to type-check.
    dim_1 = not_equivariant.irreps_in1.dim
    assert dim_1 == not_equivariant.irreps_in2.dim
    assert dim_1 == not_equivariant.irreps_out.dim

    with pytest.raises(AssertionError):
        assert_equivariant(not_equivariant)
def test_fully_connected():
    spec = "1e + 2e + 3x3o"
    irreps_in1, irreps_in2, irreps_out = (o3.Irreps(spec) for _ in range(3))

    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    print(m)
    m(torch.randn(irreps_in1.dim), torch.randn(irreps_in2.dim))

    assert_equivariant(m)
    assert_auto_jitable(m)
def __init__(self, irreps_in, irreps_out):
    """Identity map between two equal (after ``simplify``) irreps.

    Raises
    ------
    AssertionError
        If the simplified ``irreps_in`` and ``irreps_out`` differ.
    """
    super().__init__()
    self.irreps_in = o3.Irreps(irreps_in).simplify()
    self.irreps_out = o3.Irreps(irreps_out).simplify()
    assert self.irreps_in == self.irreps_out

    # Every output component is produced, so the mask is all ones.
    # ``irreps_out.dim`` equals sum(mul * (2l + 1)), so this matches the old
    # per-irrep concatenation. It also works for empty irreps, where
    # torch.cat([]) raises.
    output_mask = torch.ones(self.irreps_out.dim)
    self.register_buffer('output_mask', output_mask)
def __init__(self, irreps_in, irreps_out, irreps_sh, diameter, num_radial_basis, steps=(1.0, 1.0, 1.0), **kwargs):
    """Equivariant convolution on a regular 3D voxel grid.

    Parameters
    ----------
    irreps_in, irreps_out, irreps_sh
        Input, output and spherical-harmonics irreps.
    diameter
        Kernel diameter. The kernel spans radius ``diameter / 2``.
    num_radial_basis
        Number of radial basis functions (``'smooth_finite'``, with cutoff).
    steps
        Grid spacing along x, y and z.
    **kwargs
        Stored in ``self.kwargs``. ``padding`` defaults to half the kernel size
        per axis (presumably passed to the convolution call in ``forward``).
    """
    super().__init__()
    self.irreps_in = o3.Irreps(irreps_in)
    self.irreps_out = o3.Irreps(irreps_out)
    self.irreps_sh = o3.Irreps(irreps_sh)
    self.num_radial_basis = num_radial_basis

    # self-connection
    self.sc = Linear(self.irreps_in, self.irreps_out)

    # connection with neighbors
    r = diameter / 2

    def _axis(step):
        # Symmetric grid coordinates -n*step .. n*step, with n = floor(r / step).
        n = math.floor(r / step)
        return torch.arange(-n, n + 1.0) * step

    x, y, z = _axis(steps[0]), _axis(steps[1]), _axis(steps[2])

    lattice = torch.stack(torch.meshgrid(x, y, z), dim=-1)  # [x, y, z, R^3]
    self.register_buffer('lattice', lattice)

    kwargs.setdefault('padding', tuple(size // 2 for size in lattice.shape[:3]))
    self.kwargs = kwargs

    # Radial embedding of each lattice point's distance from the centre.
    emb = soft_one_hot_linspace(
        x=lattice.norm(dim=-1),
        start=0.0,
        end=r,
        number=self.num_radial_basis,
        basis='smooth_finite',
        cutoff=True,
    )
    self.register_buffer('emb', emb)

    # Angular part of the kernel at each lattice point.
    sh = o3.spherical_harmonics(
        l=self.irreps_sh,
        x=lattice,
        normalize=True,
        normalization='component'
    )  # [x, y, z, irreps_sh.dim]
    self.register_buffer('sh', sh)

    self.tp = FullyConnectedTensorProduct(self.irreps_in, self.irreps_sh, self.irreps_out, shared_weights=False)

    # One tensor-product weight vector per radial basis function.
    self.weight = torch.nn.Parameter(torch.randn(self.num_radial_basis, self.tp.weight_numel))
def __init__(self, irreps: o3.Irreps, act, res, normalization='component', lmax_out=None, random_rot=False):
    """Point-wise activation applied to a signal on the sphere.

    The input holds one copy of each degree ``l = 0..lmax``. It is converted to a
    grid of resolution ``res`` (``o3.ToS2Grid``), ``act`` is applied (after
    ``normalize2mom``), and the result is projected back with ``o3.FromS2Grid``
    up to ``lmax_out`` (default ``lmax``).

    Raises
    ------
    AssertionError
        If the input is not exactly ``1x0 + 1x1 + ... + 1xlmax`` or its parity is
        ill-defined. These are explicit raises, so they also fire under ``python -O``.
    ValueError
        If the input is odd (``p_val == -1``) and ``act`` is neither even nor odd.
    """
    super().__init__()

    irreps = o3.Irreps(irreps).simplify()
    _, (_, p_val) = irreps[0]
    _, (lmax, _) = irreps[-1]

    # Explicit raises instead of ``assert``: with ``python -O`` the asserts vanish,
    # and ``p_arg`` below would be left unbound (UnboundLocalError).
    if not all(mul == 1 for mul, _ in irreps):
        raise AssertionError("all multiplicities must be 1")
    if irreps.ls != list(range(lmax + 1)):
        raise AssertionError("the input must contain every degree 0..lmax exactly once")

    if all(p == p_val for _, (l, p) in irreps):
        p_arg = 1
    elif all(p == p_val * (-1)**l for _, (l, p) in irreps):
        p_arg = -1
    else:
        raise AssertionError("the parity of the input is not well defined")

    self.irreps_in = irreps
    # the input transforms as : A_l ---> p_val * (p_arg)^l * A_l
    # the sphere signal transforms as : f(r) ---> p_val * f(p_arg * r)
    if lmax_out is None:
        lmax_out = lmax

    if p_val in (0, +1):
        self.irreps_out = o3.Irreps([(1, (l, p_val * p_arg**l)) for l in range(lmax_out + 1)])
    elif p_val == -1:
        # Probe the activation's own parity numerically on [0, 10].
        x = torch.linspace(0, 10, 256)
        a1, a2 = act(x), act(-x)
        if (a1 - a2).abs().max() < a1.abs().max() * 1e-10:
            # p_act = 1
            self.irreps_out = o3.Irreps([(1, (l, p_arg**l)) for l in range(lmax_out + 1)])
        elif (a1 + a2).abs().max() < a1.abs().max() * 1e-10:
            # p_act = -1
            self.irreps_out = o3.Irreps([(1, (l, -p_arg**l)) for l in range(lmax_out + 1)])
        else:
            # p_act = 0
            raise ValueError("warning! the parity is violated")

    self.to_s2 = o3.ToS2Grid(lmax, res, normalization=normalization)
    self.from_s2 = o3.FromS2Grid(res, lmax_out, normalization=normalization, lmax_in=lmax)
    self.act = normalize2mom(act)
    self.random_rot = random_rot
def __init__(self, irreps_node_input, irreps_node_attr, irreps_edge_attr, irreps_node_output, fc_neurons, num_neighbors) -> None:
    """Point convolution whose edge weights come from a radial network.

    Parameters
    ----------
    irreps_node_input, irreps_node_attr, irreps_edge_attr, irreps_node_output
        Irreps of node features, node attributes, edge attributes and outputs.
    fc_neurons
        Hidden sizes of the radial network. ``tp.weight_numel`` is appended as
        the final layer size.
    num_neighbors
        Stored as-is (presumably a normalisation constant used in ``forward``).
    """
    super().__init__()
    self.irreps_node_input = o3.Irreps(irreps_node_input)
    self.irreps_node_attr = o3.Irreps(irreps_node_attr)
    self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
    self.irreps_node_output = o3.Irreps(irreps_node_output)
    self.num_neighbors = num_neighbors

    self.sc = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_output)
    self.lin1 = FullyConnectedTensorProduct(self.irreps_node_input, self.irreps_node_attr, self.irreps_node_input)

    # Collect (node_in x edge) product irreps that are wanted in the output,
    # plus the odd scalar "0o"; each gets its own 'uvu' path.
    mid_entries = []
    paths = []
    for i_node, (mul, ir_node) in enumerate(self.irreps_node_input):
        for i_edge, (_, ir_edge) in enumerate(self.irreps_edge_attr):
            for ir_prod in ir_node * ir_edge:
                keep = ir_prod in self.irreps_node_output or ir_prod == o3.Irrep(0, 1)
                if not keep:
                    continue
                paths.append((i_node, i_edge, len(mid_entries), 'uvu', True))
                mid_entries.append((mul, ir_prod))

    irreps_mid, perm, _ = o3.Irreps(mid_entries).sort()
    # Point each path at the slot its output irrep occupies after sorting.
    paths = [(a, b, perm[k], mode, train) for a, b, k, mode, train in paths]

    tp = TensorProduct(
        self.irreps_node_input,
        self.irreps_edge_attr,
        irreps_mid,
        paths,
        internal_weights=False,
        shared_weights=False,
    )
    self.fc = FullyConnectedNet(fc_neurons + [tp.weight_numel], torch.nn.functional.silu)
    self.tp = tp

    self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_node_output)
    self.lin3 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, "0e")
def test_jit_trace():
    @compile_mode('trace')
    class NotTracable(torch.nn.Module):
        def forward(self, param):
            # Output shape depends on input shape: a single trace cannot capture this.
            return torch.ones(8) if param.shape[0] == 7 else torch.randn(8, 3)

    module = NotTracable()
    module.irreps_in = o3.Irreps("2x0e")
    module.irreps_out = o3.Irreps("1x1o")

    # TorchScript raises assorted exception types here, so catch broadly.
    with pytest.raises(Exception):
        assert_auto_jitable(module)
def __init__(
    self,
    irreps_in1,
    irreps_in2,
    filter_ir_out=None,
    **kwargs
):
    """Element-wise tensor product of two irreps with the same number of irreps.

    Both inputs are simplified, then their blocks are split so that position i
    of each has the same multiplicity. Each aligned pair is coupled with a
    ``'uuu'`` path to every irrep of their product (optionally restricted to
    ``filter_ir_out``).

    Raises
    ------
    AssertionError
        If the two inputs have a different total number of irreps.
    """
    irreps_in1 = o3.Irreps(irreps_in1).simplify()
    irreps_in2 = o3.Irreps(irreps_in2).simplify()
    if filter_ir_out is not None:
        filter_ir_out = [o3.Irrep(ir) for ir in filter_ir_out]

    assert irreps_in1.num_irreps == irreps_in2.num_irreps

    left = list(irreps_in1)
    right = list(irreps_in2)

    # Align multiplicities: split the larger block at each position so both
    # sides have the same multiplicity there. The lists grow as blocks split.
    pos = 0
    while pos < len(left):
        mul_l, ir_l = left[pos]
        mul_r, ir_r = right[pos]
        if mul_l < mul_r:
            right[pos] = (mul_l, ir_r)
            right.insert(pos + 1, (mul_r - mul_l, ir_r))
        elif mul_r < mul_l:
            left[pos] = (mul_r, ir_l)
            left.insert(pos + 1, (mul_l - mul_r, ir_l))
        pos += 1

    out = []
    instr = []
    for idx, ((mul, ir_l), (mul_r, ir_r)) in enumerate(zip(left, right)):
        assert mul == mul_r
        for ir in ir_l * ir_r:
            if filter_ir_out is not None and ir not in filter_ir_out:
                continue
            instr.append((idx, idx, len(out), 'uuu', False))
            out.append((mul, ir))

    super().__init__(left, right, out, instr, **kwargs)
def test_normalization(float_tolerance, instance):
    sqrt_float_tolerance = torch.sqrt(float_tolerance)
    batch, n = 20, 20
    irreps = o3.Irreps("3x0e + 4x1e")

    # 'norm' normalises the sum over vector components; 'component' normalises their mean.
    cases = [
        ('norm', lambda t: t.sum(3)),
        ('component', lambda t: t.mean(3)),
    ]
    for normalization, reduce_components in cases:
        m = BatchNorm(irreps, normalization=normalization, instance=instance)
        x = m(torch.randn(batch, n, irreps.dim).mul(5.0).add(10.0))

        scalars = x[..., :3]  # [batch, space, mul]
        assert scalars.mean([0, 1]).abs().max() < float_tolerance
        assert scalars.pow(2).mean([0, 1]).sub(1).abs().max() < sqrt_float_tolerance

        vectors = x[..., 3:].reshape(batch, n, 4, 3)  # [batch, space, mul, repr]
        assert reduce_components(vectors.pow(2)).mean([0, 1]).sub(1).abs().max() < sqrt_float_tolerance
def test_weird_irreps():
    # String input is accepted.
    o3.spherical_harmonics("0e + 1o", torch.randn(1, 3), False)

    # Unusual multiplicities.
    irreps = o3.Irreps("1x0e + 4x1o + 3x2e")
    out = o3.spherical_harmonics(irreps, torch.randn(7, 3), True)
    assert out.shape[-1] == irreps.dim

    # Bad parity: for a vector ("1o") input, L = 1 must be odd, so "4x1e" is rejected.
    with pytest.raises(ValueError):
        o3.SphericalHarmonics(
            irreps_out="1x0e + 4x1e + 3x2e",
            normalize=True,
            normalization='integral',
            irreps_in="1o",
        )

    # Good parity with a pseudovector input.
    o3.SphericalHarmonics(irreps_in="1e", irreps_out="1x0e + 4x1e + 3x2e", normalize=True)

    # Invalid input irreps.
    with pytest.raises(ValueError):
        o3.SphericalHarmonics(
            irreps_in="1e + 3o",
            irreps_out="1x0e + 4x1e + 3x2e",
            normalize=True,
        )
def __init__(
    self,
    muls,
    sh_lmax,
    num_layers,
    max_radius,
    num_basis,
    fc_neurons,
    num_neighbors,
    num_nodes,
    atomref=None,
) -> None:
    """Message-passing network with scalar ("0e") input and "0e + 0o" output.

    ``muls[l]`` is the hidden multiplicity for degree ``l``, used for both
    parities. ``atomref`` is registered as a buffer and may be ``None``.
    """
    super().__init__()
    self.sh_lmax = sh_lmax
    self.max_radius = max_radius
    self.num_basis = num_basis
    self.num_nodes = num_nodes
    self.register_buffer('atomref', atomref)

    hidden = []
    for l, mul in enumerate(muls):
        for p in [-1, 1]:
            hidden.append((mul, (l, p)))
    irreps_node_hidden = o3.Irreps(hidden)

    self.mp = MessagePassing(
        irreps_node_input="0e",
        irreps_node_hidden=irreps_node_hidden,
        irreps_node_output="0e + 0o",
        irreps_node_attr="5x0e",
        irreps_edge_attr=o3.Irreps.spherical_harmonics(sh_lmax),
        layers=num_layers,
        fc_neurons=[self.num_basis] + fc_neurons,
        num_neighbors=num_neighbors,
    )
def test_module(normalization, normalize):
    irreps = o3.Irreps("0e + 1o + 3o")
    sp = o3.SphericalHarmonics(irreps, normalize, normalization)
    sp_jit = assert_auto_jitable(sp)

    # The scripted module must agree with the functional API.
    xyz = torch.randn(11, 3)
    expected = o3.spherical_harmonics(irreps, xyz, normalize, normalization)
    assert torch.allclose(sp_jit(xyz), expected)
    assert_equivariant(sp)
def __init__(self) -> None:
    """Two fully-connected tensor products on degree-3 spherical harmonics.

    Output irreps are ``"0o + 6x0e"`` (as realised by ``tp2``).
    """
    super().__init__()
    self.irreps_sh = o3.Irreps.spherical_harmonics(3)
    hidden = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
    target = o3.Irreps("0o + 6x0e")

    self.tp1 = FullyConnectedTensorProduct(
        irreps_in1=self.irreps_sh,
        irreps_in2=self.irreps_sh,
        irreps_out=hidden,
    )
    self.tp2 = FullyConnectedTensorProduct(
        irreps_in1=hidden,
        irreps_in2=self.irreps_sh,
        irreps_out=target,
    )
    self.irreps_out = self.tp2.irreps_out
def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
    """evaluate the network

    Parameters
    ----------
    data : `torch_geometric.data.Data` or dict
        data object containing
        - ``pos`` the position of the nodes (atoms)
        - ``x`` the input features of the nodes, optional
        - ``z`` the attributes of the nodes, for instance the atom type, optional
        - ``batch`` the graph to which the node belong, optional

    Returns
    -------
    torch.Tensor
        Per-node features, or, if ``self.reduce_output``, per-graph sums
        divided by ``sqrt(self.num_nodes)``.
    """
    pos = data['pos']
    n_nodes = pos.shape[0]

    # Without an explicit batch vector, every node belongs to graph 0.
    if 'batch' in data:
        batch = data['batch']
    else:
        batch = pos.new_zeros(n_nodes, dtype=torch.long)

    edge_index = radius_graph(pos, self.max_radius, batch)
    edge_src = edge_index[0]
    edge_dst = edge_index[1]

    edge_vec = pos[edge_src] - pos[edge_dst]
    edge_sh = o3.spherical_harmonics(self.irreps_edge_attr, edge_vec, True, normalization='component')
    edge_length = edge_vec.norm(dim=1)
    edge_length_embedded = soft_one_hot_linspace(
        x=edge_length,
        start=0.0,
        end=self.max_radius,
        number=self.number_of_basis,
        basis='gaussian',
        cutoff=False,
    ).mul(self.number_of_basis**0.5)
    # Fade edge attributes smoothly as the length approaches max_radius.
    edge_attr = smooth_cutoff(edge_length / self.max_radius)[:, None] * edge_sh

    if self.input_has_node_in and 'x' in data:
        assert self.irreps_in is not None
        x = data['x']
    else:
        assert self.irreps_in is None
        x = pos.new_ones((n_nodes, 1))

    if self.input_has_node_attr and 'z' in data:
        z = data['z']
    else:
        assert self.irreps_node_attr == o3.Irreps("0e")
        z = pos.new_ones((n_nodes, 1))

    for lay in self.layers:
        x = lay(x, z, edge_src, edge_dst, edge_attr, edge_length_embedded)

    if not self.reduce_output:
        return x
    return scatter(x, batch, dim=0).div(self.num_nodes**0.5)
def test_arithmetic():
    assert 3 * o3.Irrep("6o") == o3.Irreps("3x6o")

    # 1o x 2e decomposes into degrees 1..3 with odd parity.
    assert list(o3.Irrep("1o") * o3.Irrep("2e")) == [o3.Irrep(s) for s in ("1o", "2o", "3o")]

    assert o3.Irrep("4o") + o3.Irrep("7e") == o3.Irreps("4o + 7e")

    # Integer multiplication repeats the sequence (from either side).
    base = "2x2e + 4x1o"
    doubled = o3.Irreps("2x2e + 4x1o + 2x2e + 4x1o")
    assert 2 * o3.Irreps(base) == doubled
    assert o3.Irreps(base) * 2 == doubled

    assert o3.Irreps("1o + 4o") + o3.Irreps("1o + 7e") == o3.Irreps("1o + 4o + 1o + 7e")
def __init__(self, irreps_in, irreps_out, irreps_sh, dim_key):
    """Convolution whose tensor-product weights are indexed by a key of size ``dim_key``.

    ``irreps_in``, ``irreps_out`` and ``irreps_sh`` must already be ``o3.Irreps``
    (``.simplify()`` is called on them directly). The self-interaction and both
    linear layers are ``FullyConnectedTensorProduct`` layers with a ``"5x0e"`` second input.
    """
    super().__init__()
    self.irreps_in = irreps_in.simplify()
    self.irreps_out = irreps_out.simplify()
    self.irreps_sh = irreps_sh.simplify()

    self.si = FullyConnectedTensorProduct(self.irreps_in, o3.Irreps("5x0e"), self.irreps_out)
    self.lin1 = FullyConnectedTensorProduct(self.irreps_in, o3.Irreps("5x0e"), self.irreps_in)

    # (l, p) pairs present in the output, computed once rather than per iteration.
    wanted = {(l, p) for _, (l, p) in self.irreps_out}

    instr = []
    irreps = []  # unique (mul, l, p) intermediate irreps, in first-seen order
    for i_1, (mul_1, (l_1, p_1)) in enumerate(self.irreps_in):
        for i_2, (_, (l_2, p_2)) in enumerate(self.irreps_sh):
            p_out = p_1 * p_2
            for l_out in range(abs(l_1 - l_2), l_1 + l_2 + 1):
                if (l_out, p_out) not in wanted:
                    continue
                key = (mul_1, l_out, p_out)
                if key not in irreps:
                    irreps.append(key)
                instr.append((i_1, i_2, irreps.index(key), 'uvu', True))
    irreps = o3.Irreps(irreps)

    self.tp = TensorProduct(self.irreps_in, self.irreps_sh, irreps, instr, internal_weights=False, shared_weights=False)
    self.tp_weight = torch.nn.Parameter(torch.randn(dim_key, self.tp.weight_numel))
    self.lin2 = FullyConnectedTensorProduct(irreps, o3.Irreps("5x0e"), self.irreps_out)
def __init__(self, *irreps_outs):
    """Split a sorted concatenation of ``irreps_outs`` back into its parts.

    The expected input irreps are the sorted concatenation of the simplified
    ``irreps_outs``, stored (simplified) as ``self.irreps_in``. ``self.cut``
    (an ``Extract``) yields one output per entry of ``irreps_outs``.
    """
    super().__init__()
    self.irreps_outs = tuple(o3.Irreps(irreps).simplify() for irreps in irreps_outs)
    irreps_in = sum(self.irreps_outs, o3.Irreps([]))

    # Consecutive index ranges into the unsorted concatenation, one per output.
    instructions = []
    start = 0
    for irreps_out in self.irreps_outs:
        stop = start + len(irreps_out)
        instructions.append(tuple(range(start, stop)))
        start = stop
    assert len(irreps_in) == start, (len(irreps_in), start)

    # Map each index to its position after sorting.
    irreps_in, p, _ = irreps_in.sort()
    instructions = [tuple(p[i] for i in x) for x in instructions]

    self.cut = Extract(irreps_in, self.irreps_outs, instructions)
    self.irreps_in = irreps_in.simplify()
def __init__(self, irreps_in, squared: bool = False):
    """Per-channel norm of each irrep, returned as scalars (``mul x 0e`` per block).

    ``squared`` is stored on the module (presumably used in ``forward`` to
    decide whether to take a square root — confirm).
    """
    super().__init__()
    irreps_in = o3.Irreps(irreps_in).simplify()
    irreps_out = o3.Irreps([(mul, "0e") for mul, _ in irreps_in])

    # Each irrep is paired with itself channel-by-channel ('uuu'), with path
    # weight ir.dim. Presumably this cancels the 'component' normalization so
    # the output is the squared norm — confirm against o3.TensorProduct.
    instr = []
    for idx, (_, ir) in enumerate(irreps_in):
        instr.append((idx, idx, idx, 'uuu', False, ir.dim))

    self.tp = o3.TensorProduct(irreps_in, irreps_in, irreps_out, instr, normalization='component')
    self.irreps_in = irreps_in
    self.irreps_out = irreps_out.simplify()
    self.squared = squared