def __init__(self) -> None:
    """Build a fixed network: convolution, gate, then a final convolution.

    Assigns, in this order, ``self.conv``, ``self.gate`` and ``self.final``,
    and exposes the irreps of the final convolution as ``self.irreps_out``.
    """
    super().__init__()

    self.num_neighbors = 3.8  # typical number of neighbors
    self.irreps_sh = o3.Irreps.spherical_harmonics(3)

    # First layer with gate. The gate is created before the convolution
    # because the convolution is asked to output ``gate.irreps_in``.
    scalar_irreps = "16x0e + 16x0o"
    scalar_acts = [torch.relu, torch.abs]
    gate_irreps = "8x0e + 8x0o + 8x0e + 8x0o"
    gate_acts = [torch.relu, torch.tanh, torch.relu, torch.tanh]
    # gated tensors, num_irreps has to match with gates
    gated_irreps = "16x1o + 16x1e"
    hidden_gate = Gate(
        scalar_irreps,
        scalar_acts,
        gate_irreps,
        gate_acts,
        gated_irreps,
    )

    # The spherical-harmonics irreps are passed as both the first and the
    # second argument of the first convolution.
    self.conv = Convolution(
        self.irreps_sh,
        self.irreps_sh,
        hidden_gate.irreps_in,
        self.num_neighbors,
    )
    self.gate = hidden_gate

    # Final layer: consumes whatever the gate emits.
    self.final = Convolution(
        self.gate.irreps_out,
        self.irreps_sh,
        "0o + 6x0e",
        self.num_neighbors,
    )
    self.irreps_out = self.final.irreps_out
def test_gate():
    """Check the gate shortcut and the ``Gate`` module on odd scalars.

    The shortcut built from the scalar and gate irreps must be
    auto-jitable; the full gate must be equivariant, auto-jitable and
    normalized.
    """
    scalars = Irreps("16x0o")
    scalar_acts = [torch.tanh]
    gates = Irreps("32x0o")
    gate_acts = [torch.tanh]
    # 16 + 16 gated irreps, one per gate scalar above.
    gated = Irreps("16x1e+16x1o")

    shortcut = _Sortcut(scalars, gates)
    assert_auto_jitable(shortcut)

    gate_module = Gate(scalars, scalar_acts, gates, gate_acts, gated)
    assert_equivariant(gate_module)
    assert_auto_jitable(gate_module)
    assert_normalized(gate_module)
def __init__(self, irreps_in1, irreps_out, irreps_in2=None) -> None:
    """Set up a tensor product whose output is passed through a gate.

    The parent tensor product is built with an enlarged output: the first
    entry of ``irreps_out``, one extra ``0e`` scalar per remaining irrep
    (used as gate values), and the remaining entries of ``irreps_out``.

    Args:
        irreps_in1: forwarded unchanged to the parent constructor.
        irreps_out: requested output irreps; its first entry is assumed to
            be scalar and goes through the activation, the rest are gated.
        irreps_in2: forwarded unchanged to the parent constructor.
    """
    # First type: assumed scalar, passed through the activation directly.
    activated = Irreps(str(irreps_out[0]))
    # Remaining types: these are the ones that get gated.
    gated = Irreps(str(irreps_out[1:]))
    # One additional 0e scalar is needed for every irrep that is not part
    # of the first entry.
    extra_scalars = irreps_out.num_irreps - activated.num_irreps
    gate_values = Irreps("{}x0e".format(extra_scalars))

    # What the gate takes as input is what the tensor product must output.
    tensor_product_out = (activated + gate_values + gated).simplify()

    # Build the layers
    super(O3TensorProductSwishGate, self).__init__(
        irreps_in1, tensor_product_out, irreps_in2
    )

    # With nothing to gate, a plain Swish on the output is enough.
    has_gated = gated.num_irreps > 0
    self.gate = (
        Gate(activated, [Swish()], gate_values, [torch.sigmoid], gated)
        if has_gated
        else Swish()
    )
def make_gated_block(irreps_in, muls, irreps_sh):
    """Make a Gate assuming many things.

    Args:
        irreps_in: input irreps; ``simplify()`` is called on it and the
            result is iterated as ``(mul, (l, p))`` entries.
        muls: sequence indexed by degree ``l``; ``muls[l]`` is the
            multiplicity used for that degree.
        irreps_sh: iterated as ``(mul, (l, p))`` entries and combined with
            every entry of ``irreps_in``.

    Returns:
        A ``Gate`` built from the scalar, gate and non-scalar irreps that
        are reachable by combining ``irreps_in`` with ``irreps_sh``.
    """
    # Every (degree, parity) pair obtainable by combining an input irrep
    # with an irrep of irreps_sh: degrees |l_in - l_sh| .. l_in + l_sh.
    reachable = []
    for _, (l_in, p_in) in irreps_in.simplify():
        for _, (l_sh, p_sh) in irreps_sh:
            parity = p_in * p_sh
            for degree in range(abs(l_in - l_sh), l_in + l_sh + 1):
                reachable.append((degree, parity))

    # Scalars: both parities, kept only if reachable.
    scalar_entries = []
    for parity in (1, -1):
        if (0, parity) in reachable:
            scalar_entries.append((muls[0], 0, parity))
    scalars = o3.Irreps(scalar_entries)
    act_scalars = [
        swish if parity == 1 else torch.tanh for _, (_, parity) in scalars
    ]

    # Non-scalars: for each degree >= 1 try the parities (-1)**l and
    # -(-1)**l, in that order, keeping only reachable ones.
    nonscalar_entries = []
    for degree in range(1, len(muls)):
        for sign in (1, -1):
            parity = sign * (-1) ** degree
            if (degree, parity) in reachable:
                nonscalar_entries.append((muls[degree], degree, parity))
    nonscalars = o3.Irreps(nonscalar_entries)

    # One gate scalar per non-scalar irrep; parity +1 scalars with sigmoid
    # are used when reachable, otherwise parity -1 scalars with tanh.
    gate_parity, gate_act = (
        (+1, torch.sigmoid) if (0, +1) in reachable else (-1, torch.tanh)
    )
    gates = o3.Irreps([(nonscalars.num_irreps, 0, gate_parity)])
    act_gates = [gate_act]

    return Gate(scalars, act_scalars, gates, act_gates, nonscalars)
def __init__(
    self,
    irreps_in,
    irreps_hidden,
    irreps_out,
    irreps_node_attr,
    irreps_edge_attr,
    layers,
    max_radius,
    number_of_basis,
    radial_layers,
    radial_neurons,
    num_neighbors,
    num_nodes,
    reduce_output=True,
) -> None:
    """Build ``layers`` gated convolutions followed by one output convolution.

    Args:
        irreps_in: irreps of the node input, or ``None``; when ``None`` the
            first convolution is built for ``"0e"`` input.
        irreps_hidden: candidate irreps for the hidden layers; each layer
            keeps only those reachable from its input via
            ``tp_path_exists``.
        irreps_out: output irreps of the final convolution.
        irreps_node_attr: irreps of the node attributes, or ``None``
            (replaced by ``"0e"``).
        irreps_edge_attr: irreps of the edge attributes.
        layers: number of gated convolution layers.
        max_radius: stored on the instance as given.
        number_of_basis: stored; also contributes to the number of edge
            features passed to every convolution.
        radial_layers: forwarded to every ``Convolution``.
        radial_neurons: forwarded to every ``Convolution``.
        num_neighbors: stored and forwarded to every ``Convolution``.
        num_nodes: stored on the instance as given.
        reduce_output: stored on the instance as given.
    """
    super().__init__()

    # Plain settings, kept as passed in.
    self.max_radius = max_radius
    self.number_of_basis = number_of_basis
    self.num_neighbors = num_neighbors
    self.num_nodes = num_nodes
    self.reduce_output = reduce_output

    has_node_in = irreps_in is not None
    has_node_attr = irreps_node_attr is not None

    # Irreps, with "0e" standing in for missing node attributes.
    self.irreps_in = o3.Irreps(irreps_in) if has_node_in else None
    self.irreps_hidden = o3.Irreps(irreps_hidden)
    self.irreps_out = o3.Irreps(irreps_out)
    self.irreps_node_attr = o3.Irreps(
        irreps_node_attr if has_node_attr else "0e"
    )
    self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)

    self.input_has_node_in = has_node_in
    self.input_has_node_attr = has_node_attr

    self.ext_z = ExtractIr(self.irreps_node_attr, '0e')
    # Radial basis size plus two entries per 0e node attribute.
    number_of_edge_features = (
        number_of_basis + 2 * self.irreps_node_attr.count('0e')
    )

    # Irreps of the features entering the next layer.
    current = self.irreps_in
    if current is None:
        current = o3.Irreps("0e")

    # Activations selected by the parity ``p`` of each scalar irrep.
    scalar_act_by_parity = {
        1: torch.nn.functional.silu,
        -1: torch.tanh,
    }
    gate_act_by_parity = {
        1: torch.sigmoid,
        -1: torch.tanh,
    }

    self.layers = torch.nn.ModuleList()

    for _ in range(layers):
        # Hidden scalars (l == 0) reachable from the current features.
        scalar_entries = []
        for mul, ir in self.irreps_hidden:
            if ir.l == 0 and tp_path_exists(
                    current, self.irreps_edge_attr, ir):
                scalar_entries.append((mul, ir))
        irreps_scalars = o3.Irreps(scalar_entries)

        # Hidden non-scalars (l > 0) reachable from the current features.
        gated_entries = []
        for mul, ir in self.irreps_hidden:
            if ir.l > 0 and tp_path_exists(
                    current, self.irreps_edge_attr, ir):
                gated_entries.append((mul, ir))
        irreps_gated = o3.Irreps(gated_entries)

        # Gate scalars are "0e" when reachable, otherwise "0o"; one gate
        # entry is created per gated entry, with the same multiplicity.
        if tp_path_exists(current, self.irreps_edge_attr, "0e"):
            gate_ir = "0e"
        else:
            gate_ir = "0o"
        irreps_gates = o3.Irreps([(mul, gate_ir) for mul, _ in irreps_gated])

        gate = Gate(
            irreps_scalars,
            [scalar_act_by_parity[irrep.p] for _, irrep in irreps_scalars],
            irreps_gates,
            [gate_act_by_parity[irrep.p] for _, irrep in irreps_gates],
            irreps_gated,
        )
        # The convolution outputs exactly what the gate takes as input.
        conv = Convolution(
            current,
            self.irreps_node_attr,
            self.irreps_edge_attr,
            gate.irreps_in,
            number_of_edge_features,
            radial_layers,
            radial_neurons,
            num_neighbors,
        )
        current = gate.irreps_out
        self.layers.append(Compose(conv, gate))

    # Output layer: a convolution without a gate.
    output_conv = Convolution(
        current,
        self.irreps_node_attr,
        self.irreps_edge_attr,
        self.irreps_out,
        number_of_edge_features,
        radial_layers,
        radial_neurons,
        num_neighbors,
    )
    self.layers.append(output_conv)
def __init__(
    self,
    irreps_node_input,
    irreps_node_hidden,
    irreps_node_output,
    irreps_node_attr,
    irreps_edge_attr,
    layers,
    fc_neurons,
    num_neighbors,
) -> None:
    """Build ``layers`` gated convolutions followed by one output convolution.

    Args:
        irreps_node_input: irreps of the node features entering the first
            convolution.
        irreps_node_hidden: candidate irreps for the hidden layers; each
            layer keeps only those reachable from its input via
            ``tp_path_exists``.
        irreps_node_output: output irreps of the final convolution.
        irreps_node_attr: irreps of the node attributes.
        irreps_edge_attr: irreps of the edge attributes.
        layers: number of gated convolution layers.
        fc_neurons: forwarded to every ``Convolution``.
        num_neighbors: stored and forwarded to every ``Convolution``.
    """
    super().__init__()
    self.num_neighbors = num_neighbors

    self.irreps_node_input = o3.Irreps(irreps_node_input)
    self.irreps_node_hidden = o3.Irreps(irreps_node_hidden)
    self.irreps_node_output = o3.Irreps(irreps_node_output)
    self.irreps_node_attr = o3.Irreps(irreps_node_attr)
    self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)

    # Activations selected by the parity ``p`` of each scalar irrep.
    scalar_act_by_parity = {1: torch.nn.functional.silu, -1: torch.tanh}
    gate_act_by_parity = {1: torch.sigmoid, -1: torch.tanh}

    self.layers = torch.nn.ModuleList()

    # Irreps of the node features entering the next layer.
    current = self.irreps_node_input
    for _ in range(layers):
        # Hidden scalars (l == 0) reachable from the current features,
        # merged with simplify().
        scalar_entries = []
        for mul, ir in self.irreps_node_hidden:
            if ir.l == 0 and tp_path_exists(
                    current, self.irreps_edge_attr, ir):
                scalar_entries.append((mul, ir))
        irreps_scalars = o3.Irreps(scalar_entries).simplify()

        # Hidden non-scalars (l > 0) reachable from the current features;
        # deliberately left unsimplified, as in the scalar-free listing.
        gated_entries = []
        for mul, ir in self.irreps_node_hidden:
            if ir.l > 0 and tp_path_exists(
                    current, self.irreps_edge_attr, ir):
                gated_entries.append((mul, ir))
        irreps_gated = o3.Irreps(gated_entries)

        # Gate scalars are "0e" when reachable, otherwise "0o"; one gate
        # entry per gated entry, then merged with simplify().
        if tp_path_exists(current, self.irreps_edge_attr, "0e"):
            gate_ir = "0e"
        else:
            gate_ir = "0o"
        irreps_gates = o3.Irreps(
            [(mul, gate_ir) for mul, _ in irreps_gated]
        ).simplify()

        gate = Gate(
            irreps_scalars,
            [scalar_act_by_parity[irrep.p] for _, irrep in irreps_scalars],
            irreps_gates,
            [gate_act_by_parity[irrep.p] for _, irrep in irreps_gates],
            irreps_gated,
        )
        # The convolution outputs exactly what the gate takes as input.
        conv = Convolution(
            current,
            self.irreps_node_attr,
            self.irreps_edge_attr,
            gate.irreps_in,
            fc_neurons,
            num_neighbors,
        )
        current = gate.irreps_out
        self.layers.append(Compose(conv, gate))

    # Output layer: a convolution without a gate.
    output_conv = Convolution(
        current,
        self.irreps_node_attr,
        self.irreps_edge_attr,
        self.irreps_node_output,
        fc_neurons,
        num_neighbors,
    )
    self.layers.append(output_conv)