Пример #1
0
    def __init__(self) -> None:
        """Assemble a gated convolution stage followed by a final convolution.

        Takes no arguments: every irreps string, activation and the
        neighbor count are hard-coded below. The irreps produced by the
        final convolution are exposed as ``self.irreps_out``.
        """
        super().__init__()
        self.num_neighbors = 3.8  # typical number of neighbors
        self.irreps_sh = o3.Irreps.spherical_harmonics(3)

        # Non-linearity of the first stage. It is built before the
        # convolution because the convolution's output irreps are taken
        # from ``irreps_in`` of this gate.
        first_gate = Gate(
            "16x0e + 16x0o",  # scalars
            [torch.relu, torch.abs],  # one activation per scalar entry
            "8x0e + 8x0o + 8x0e + 8x0o",  # gates (scalars)
            [torch.relu, torch.tanh, torch.relu, torch.tanh],
            "16x1o + 16x1e",  # gated tensors, num_irreps has to match with gates
        )

        # First stage: ``irreps_sh`` is passed as both of the first two
        # Convolution arguments.
        self.conv = Convolution(
            self.irreps_sh,
            self.irreps_sh,
            first_gate.irreps_in,
            self.num_neighbors,
        )
        self.gate = first_gate

        # Final stage: consumes whatever the gate emits.
        self.final = Convolution(
            self.gate.irreps_out,
            self.irreps_sh,
            "0o + 6x0e",
            self.num_neighbors,
        )
        self.irreps_out = self.final.irreps_out
Пример #2
0
def test_gate():
    """Run the jit, equivariance and normalization helpers on a Gate.

    The shortcut helper built from the same scalar and gate irreps is
    also checked for jit-compatibility.
    """
    scalar_irreps = Irreps("16x0o")
    scalar_acts = [torch.tanh]
    gate_irreps = Irreps("32x0o")
    gate_acts = [torch.tanh]
    gated_irreps = Irreps("16x1e+16x1o")

    shortcut = _Sortcut(scalar_irreps, gate_irreps)
    assert_auto_jitable(shortcut)

    gate = Gate(
        scalar_irreps,
        scalar_acts,
        gate_irreps,
        gate_acts,
        gated_irreps,
    )
    assert_equivariant(gate)
    assert_auto_jitable(gate)
    assert_normalized(gate)
Пример #3
0
    def __init__(self, irreps_in1, irreps_out, irreps_in2=None) -> None:
        """Set up a tensor product whose output is passed through a gate.

        The tensor product is asked for more irreps than ``irreps_out``:
        one extra "0e" scalar per gated irrep, which the gate consumes.

        Args:
            irreps_in1: Forwarded unchanged to the parent constructor.
            irreps_out: Target irreps. It is indexed and sliced here, and
                its first entry is assumed to be scalar.
            irreps_in2: Forwarded unchanged to the parent constructor.
        """
        # Leading entry: assumed scalar, goes straight through the activation.
        scalar_part = Irreps(str(irreps_out[0]))
        # Everything after the leading entry is gated.
        gated_part = Irreps(str(irreps_out[1:]))
        # One even scalar gate for every irrep that is not in the scalar part.
        gate_count = irreps_out.num_irreps - scalar_part.num_irreps
        gate_part = Irreps("{}x0e".format(gate_count))

        # Input layout of the gate, and therefore the output irreps requested
        # from the tensor product.
        product_irreps = (scalar_part + gate_part + gated_part).simplify()

        # Build the layers
        super(O3TensorProductSwishGate, self).__init__(irreps_in1, product_irreps, irreps_in2)

        if gated_part.num_irreps <= 0:
            # Nothing to gate: a plain activation is enough.
            self.gate = Swish()
        else:
            self.gate = Gate(
                scalar_part,
                [Swish()],
                gate_part,
                [torch.sigmoid],
                gated_part,
            )
Пример #4
0
def make_gated_block(irreps_in, muls, irreps_sh):
    """
    Make a Gate assuming many things

    The set of ``(l, p)`` pairs obtainable by combining each entry of
    ``irreps_in`` with each entry of ``irreps_sh`` is computed first; only
    scalars and non-scalars found in that set are given to the Gate.

    Args:
        irreps_in: Irreps object; ``simplify()`` is called on it and the
            result is iterated as ``(mul, (l, p))`` entries.
        muls: Sequence indexed by degree: ``muls[0]`` is the multiplicity
            used for scalars and ``muls[l]`` for degree ``l >= 1``. Its
            length bounds the degrees considered.
        irreps_sh: Iterable of ``(mul, (l, p))`` entries (iterated as is,
            without ``simplify()``).

    Returns:
        ``Gate(scalars, act_scalars, gates, act_gates, nonscalars)``.
    """
    # Only membership is ever queried, so a set replaces the original list:
    # duplicates collapse and each lookup avoids a linear scan.
    irreps_available = {
        (l_out, p_in * p_sh)
        for _, (l_in, p_in) in irreps_in.simplify()
        for _, (l_sh, p_sh) in irreps_sh
        for l_out in range(abs(l_in - l_sh), l_in + l_sh + 1)
    }

    # Scalars: even first, then odd, each only if reachable.
    scalars = o3.Irreps([
        (muls[0], 0, p)
        for p in (1, -1)
        if (0, p) in irreps_available
    ])
    # Even scalars get swish, odd scalars get tanh.
    act_scalars = [swish if p == 1 else torch.tanh for _, (_, p) in scalars]

    # Non-scalars: for each degree, the parity p * (-1)**l is tried for
    # p = +1 and then p = -1, keeping the reachable ones.
    nonscalars = o3.Irreps([
        (muls[l_out], l_out, p * (-1)**l_out)
        for l_out in range(1, len(muls))
        for p in (1, -1)
        if (l_out, p * (-1)**l_out) in irreps_available
    ])

    # One gate per gated irrep. Even gates are preferred; otherwise odd
    # gates are used without checking that (0, -1) is reachable.
    if (0, +1) in irreps_available:
        gates = o3.Irreps([(nonscalars.num_irreps, 0, +1)])
        act_gates = [torch.sigmoid]
    else:
        gates = o3.Irreps([(nonscalars.num_irreps, 0, -1)])
        act_gates = [torch.tanh]

    return Gate(scalars, act_scalars, gates, act_gates, nonscalars)
Пример #5
0
    def __init__(
        self,
        irreps_in,
        irreps_hidden,
        irreps_out,
        irreps_node_attr,
        irreps_edge_attr,
        layers,
        max_radius,
        number_of_basis,
        radial_layers,
        radial_neurons,
        num_neighbors,
        num_nodes,
        reduce_output=True,
    ) -> None:
        """Build ``layers`` gated convolution blocks plus one output convolution.

        Args:
            irreps_in: Irreps of the input node features, or ``None``; when
                ``None`` the first convolution is given "0e" as input.
            irreps_hidden: Irreps from which each block picks its scalars
                (``l == 0``) and gated tensors (``l > 0``).
            irreps_out: Irreps requested from the final convolution.
            irreps_node_attr: Irreps of the node attributes, or ``None``,
                in which case "0e" is used.
            irreps_edge_attr: Irreps of the edge attributes, passed to
                every convolution.
            layers: Number of gated blocks to create.
            max_radius: Stored on the instance; not otherwise used here.
            number_of_basis: Stored on the instance and used in the edge
                feature count passed to every convolution.
            radial_layers: Forwarded to every convolution.
            radial_neurons: Forwarded to every convolution.
            num_neighbors: Stored on the instance and forwarded to every
                convolution.
            num_nodes: Stored on the instance; not otherwise used here.
            reduce_output: Stored on the instance; not otherwise used here.
        """
        super().__init__()
        self.max_radius = max_radius
        self.number_of_basis = number_of_basis
        self.num_neighbors = num_neighbors
        self.num_nodes = num_nodes
        self.reduce_output = reduce_output

        self.irreps_in = None if irreps_in is None else o3.Irreps(irreps_in)
        self.irreps_hidden = o3.Irreps(irreps_hidden)
        self.irreps_out = o3.Irreps(irreps_out)
        if irreps_node_attr is None:
            self.irreps_node_attr = o3.Irreps("0e")
        else:
            self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)

        # Record which optional inputs the caller actually described.
        self.input_has_node_in = irreps_in is not None
        self.input_has_node_attr = irreps_node_attr is not None

        self.ext_z = ExtractIr(self.irreps_node_attr, '0e')
        scalar_attr_count = self.irreps_node_attr.count('0e')
        number_of_edge_features = number_of_basis + 2 * scalar_attr_count

        # Activation lookup keyed by parity.
        scalar_act_by_parity = {
            1: torch.nn.functional.silu,
            -1: torch.tanh,
        }
        gate_act_by_parity = {
            1: torch.sigmoid,
            -1: torch.tanh,
        }

        self.layers = torch.nn.ModuleList()

        # Irreps entering the next convolution; updated after every block.
        current = o3.Irreps("0e") if self.irreps_in is None else self.irreps_in

        for _ in range(layers):
            # Keep only the hidden irreps that tp_path_exists accepts for
            # the current input and the edge attributes.
            hidden_scalars = o3.Irreps([
                (mul, ir)
                for mul, ir in self.irreps_hidden
                if ir.l == 0
                and tp_path_exists(current, self.irreps_edge_attr, ir)
            ])
            hidden_gated = o3.Irreps([
                (mul, ir)
                for mul, ir in self.irreps_hidden
                if ir.l > 0
                and tp_path_exists(current, self.irreps_edge_attr, ir)
            ])

            # Gates are "0e" when tp_path_exists accepts it, "0o" otherwise.
            if tp_path_exists(current, self.irreps_edge_attr, "0e"):
                gate_ir = "0e"
            else:
                gate_ir = "0o"
            hidden_gates = o3.Irreps(
                [(mul, gate_ir) for mul, _ in hidden_gated])

            gate = Gate(
                hidden_scalars,
                [scalar_act_by_parity[ir.p] for _, ir in hidden_scalars],
                hidden_gates,
                [gate_act_by_parity[ir.p] for _, ir in hidden_gates],
                hidden_gated,
            )
            conv = Convolution(
                current,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                gate.irreps_in,
                number_of_edge_features,
                radial_layers,
                radial_neurons,
                num_neighbors,
            )
            self.layers.append(Compose(conv, gate))
            current = gate.irreps_out

        # Output convolution: no gate, emits irreps_out directly.
        output_conv = Convolution(
            current,
            self.irreps_node_attr,
            self.irreps_edge_attr,
            self.irreps_out,
            number_of_edge_features,
            radial_layers,
            radial_neurons,
            num_neighbors,
        )
        self.layers.append(output_conv)
Пример #6
0
    def __init__(
        self,
        irreps_node_input,
        irreps_node_hidden,
        irreps_node_output,
        irreps_node_attr,
        irreps_edge_attr,
        layers,
        fc_neurons,
        num_neighbors,
    ) -> None:
        """Build ``layers`` gated convolution blocks plus one output convolution.

        Args:
            irreps_node_input: Irreps given to the first convolution.
            irreps_node_hidden: Irreps from which each block picks its
                scalars (``l == 0``) and gated tensors (``l > 0``).
            irreps_node_output: Irreps requested from the final convolution.
            irreps_node_attr: Irreps of the node attributes, passed to
                every convolution.
            irreps_edge_attr: Irreps of the edge attributes, passed to
                every convolution.
            layers: Number of gated blocks to create.
            fc_neurons: Forwarded to every convolution.
            num_neighbors: Stored on the instance and forwarded to every
                convolution.
        """
        super().__init__()
        self.num_neighbors = num_neighbors

        self.irreps_node_input = o3.Irreps(irreps_node_input)
        self.irreps_node_hidden = o3.Irreps(irreps_node_hidden)
        self.irreps_node_output = o3.Irreps(irreps_node_output)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)

        # Activation lookup keyed by parity.
        scalar_act_by_parity = {
            1: torch.nn.functional.silu,
            -1: torch.tanh,
        }
        gate_act_by_parity = {
            1: torch.sigmoid,
            -1: torch.tanh,
        }

        self.layers = torch.nn.ModuleList()

        # Irreps entering the next convolution; updated after every block.
        current = self.irreps_node_input

        for _ in range(layers):
            # Keep only the hidden irreps that tp_path_exists accepts for
            # the current input and the edge attributes.
            hidden_scalars = o3.Irreps([
                (mul, ir)
                for mul, ir in self.irreps_node_hidden
                if ir.l == 0
                and tp_path_exists(current, self.irreps_edge_attr, ir)
            ]).simplify()
            hidden_gated = o3.Irreps([
                (mul, ir)
                for mul, ir in self.irreps_node_hidden
                if ir.l > 0
                and tp_path_exists(current, self.irreps_edge_attr, ir)
            ])

            # Gates are "0e" when tp_path_exists accepts it, "0o" otherwise.
            if tp_path_exists(current, self.irreps_edge_attr, "0e"):
                gate_ir = "0e"
            else:
                gate_ir = "0o"
            hidden_gates = o3.Irreps(
                [(mul, gate_ir) for mul, _ in hidden_gated]).simplify()

            gate = Gate(
                hidden_scalars,
                [scalar_act_by_parity[ir.p] for _, ir in hidden_scalars],
                hidden_gates,
                [gate_act_by_parity[ir.p] for _, ir in hidden_gates],
                hidden_gated,
            )
            conv = Convolution(
                current,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                gate.irreps_in,
                fc_neurons,
                num_neighbors,
            )
            self.layers.append(Compose(conv, gate))
            current = gate.irreps_out

        # Output convolution: no gate, emits irreps_node_output directly.
        output_conv = Convolution(
            current,
            self.irreps_node_attr,
            self.irreps_edge_attr,
            self.irreps_node_output,
            fc_neurons,
            num_neighbors,
        )
        self.layers.append(output_conv)