Ejemplo n.º 1
0
def WeightBalancedIrreps(irreps_in1_scalar, irreps_in2, sh=True):
    """
    Search for an irreps type whose tensor-product layer matches a scalar layer in weight count.

    Candidates are tried for multiplicity n = 1, 2, 3, ... and the first one for which
        FullyConnectedTensorProduct(candidate, irreps_in2, candidate).weight_numel
    is at least
        FullyConnectedTensorProduct(irreps_in1_scalar, "1x0e", irreps_in1_scalar).weight_numel
    is selected.

    Args:
        irreps_in1_scalar: irreps of the reference layer whose weight count should be matched.
        irreps_in2: irreps of the second tensor-product input; its ``lmax`` fixes the
            order of the candidates.
        sh: if True a candidate is n copies of the spherical harmonics up to ``lmax``
            (sorted and simplified); otherwise it is ``BalancedIrreps(lmax, n)``.

    Returns:
        The selected candidate wrapped in ``Irreps``. The choice is also printed.
    """
    lmax = irreps_in2.lmax

    def candidate_for(multiplicity):
        # Either stacked copies of the spherical harmonics or a "balanced" irreps type.
        if sh:
            return (Irreps.spherical_harmonics(lmax) * multiplicity).sort().irreps.simplify()
        return BalancedIrreps(lmax, multiplicity)

    # Weight count of the reference layer: the target to reach.
    target_numel = FullyConnectedTensorProduct(
        irreps_in1_scalar, Irreps("1x0e"), irreps_in1_scalar
    ).weight_numel

    # Linear search over the multiplicity; stops at the first candidate reaching the target.
    multiplicity = 1
    while True:
        irreps_in1 = candidate_for(multiplicity)
        candidate_numel = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_in1).weight_numel
        if candidate_numel >= target_numel:
            break
        multiplicity += 1

    print('Determined irrep type:', irreps_in1)

    return Irreps(irreps_in1)
Ejemplo n.º 2
0
def test_equivariance(lmax, res_b, res_a):
    """
    Run ``assert_equivariant`` on a coefficients -> grid -> exp -> coefficients map.

    The map under test sends its input through ``ToS2Grid``, applies an
    element-wise ``exp`` on the grid values and projects back with ``FromS2Grid``.
    Both its input and output irreps are declared as the spherical harmonics up to ``lmax``.
    """
    grid_shape = (res_b, res_a)
    from_grid = FromS2Grid(grid_shape, lmax)
    to_grid = ToS2Grid(lmax, grid_shape)

    def f(x):
        # The non-linearity is applied on the grid representation.
        return from_grid(to_grid(x).exp())

    # assert_equivariant receives the irreps through attributes on the callable.
    sh_irreps = Irreps.spherical_harmonics(lmax)
    f.irreps_in = sh_irreps
    f.irreps_out = sh_irreps

    assert_equivariant(f)
Ejemplo n.º 3
0
    def __init__(
        self,
        num_atoms,  # not used
        bond_feat_dim,  # not used
        num_targets,  # not used
        in_features=9,
        out_features=1,
        hidden_features=256,
        N=7,
        dim=3,
        lmax_h=2,
        lmax_pos=2,
        update_pos=False,
        recurrent=True,
        regress_forces=False,
        use_pbc=True,
        otf_graph=False
    ):
        """
        Assemble the model: embedding layers, N SEGNN layers, output head and atom tables.

        Args:
            num_atoms, bond_feat_dim, num_targets: accepted but not used here.
            in_features: multiplicity of the scalar ("0e") input irreps.
            out_features: multiplicity of the scalar output irreps.
            hidden_features: multiplicity of the scalar hidden irreps; also the reference
                passed to ``WeightBalancedIrreps`` to pick the steerable hidden irreps.
            N: number of SEGNN layers.
            dim, lmax_h: stored on the instance; not otherwise used in this constructor.
            lmax_pos: maximum order of the spherical-harmonics attribute irreps.
            update_pos, recurrent: forwarded to every SEGNN layer.
            regress_forces, use_pbc, otf_graph: stored on the instance only.
        """
        super(SEGNNModel, self).__init__()

        # Hyper-parameters are kept on the instance.
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.N = N
        self.regress_forces = regress_forces
        self.otf_graph = otf_graph
        self.use_pbc = use_pbc
        self.update_pos = update_pos
        self.recurrent = recurrent
        self.dim = dim
        self.lmax_h = lmax_h
        self.lmax_pos = lmax_pos

        # Scalar irreps for the input, the scalar hidden features and the output.
        scalar_in = Irreps(f"{self.in_features}x0e")
        scalar_hidden = Irreps(f"{self.hidden_features}x0e")
        scalar_out = Irreps(f"{self.out_features}x0e")

        # Irreps of the edge and node attributes (spherical harmonics up to lmax_pos).
        self.attr_irreps = Irreps.spherical_harmonics(self.lmax_pos)
        steerable = self.attr_irreps

        # Hidden irreps chosen by weight count; False selects BalancedIrreps
        # candidates, True would use copies of the spherical harmonics.
        hidden = WeightBalancedIrreps(scalar_hidden, steerable, False)

        # Network for computing the node attributes.
        self.node_attribute_net = NodeAttributeNetwork()

        # Embedding: scalar input -> hidden irreps, conditioned on the attributes.
        self.embedding_layer_1 = O3TensorProductSwishGate(scalar_in, hidden, steerable)
        self.embedding_layer_2 = O3TensorProductSwishGate(hidden, hidden, steerable)
        self.embedding_layer_3 = O3TensorProduct(hidden, hidden, steerable)

        # The main layers: N SEGNN layers, each with in = hidden = out irreps.
        self.layers = nn.ModuleList([
            SEGNN(hidden, hidden, hidden, steerable,
                  update_pos=self.update_pos, recurrent=self.recurrent)
            for _ in range(self.N)
        ])

        # Output head: back to scalar irreps; only the first layer takes the attributes.
        self.head_pre_pool_layer_1 = O3TensorProductSwishGate(hidden, scalar_hidden, steerable)
        self.head_pre_pool_layer_2 = O3TensorProduct(scalar_hidden, scalar_hidden)
        self.head_post_pool_layer_1 = O3TensorProductSwishGate(scalar_hidden, scalar_hidden)
        self.head_post_pool_layer_2 = O3TensorProduct(scalar_hidden, scalar_out)

        # Atom embedding table: one 9-dimensional row per index 0..100.
        atom_map = torch.zeros(101, 9)
        for index in range(101):
            atom_map[index] = torch.tensor(CONTINUOUS_EMBEDDINGS[index])

        # Row 0 is overwritten with NaN so that it is left out of the per-column
        # min/max statistics; it remains NaN after the rescaling below.
        atom_map[0] = float("nan")
        valid_rows = atom_map[~torch.isnan(atom_map[:, 0])]
        column_min, _ = valid_rows.min(dim=0)
        column_max, _ = valid_rows.max(dim=0)
        column_gap = column_max - column_min

        # Min-max rescale every column using the statistics of the valid rows.
        atom_map = (atom_map - column_min.unsqueeze(0)) / column_gap.unsqueeze(0)
        self.atom_map = nn.Parameter(atom_map, requires_grad=False)

        # Atom radii table for indices 0..100, divided by 100, stored as a frozen parameter.
        atom_radii = torch.zeros(101)
        for index in range(101):
            atom_radii[index] = ATOMIC_RADII[index]
        self.atom_radii = nn.Parameter(atom_radii / 100, requires_grad=False)
Ejemplo n.º 4
0
    def __init__(
        self,
        num_atoms,  # not used
        bond_feat_dim,  # not used
        num_targets,  # not used
        in_features=9,
        out_features=1,
        hidden_features=256,
        N=7,
        dim=3,
        lmax_h=2,
        lmax_pos=2,
        update_pos=False,
        recurrent=True,
        regress_forces=False,
        use_pbc=True,
        otf_graph=False
    ):
        """
        Assemble the model: scalar embedding, SEGNN layer stack, output head and atom table.

        Args:
            num_atoms, bond_feat_dim, num_targets: accepted but not used here.
            in_features: multiplicity of the scalar ("0e") input irreps.
            out_features: multiplicity of the scalar output irreps.
            hidden_features: multiplicity of the scalar hidden irreps; also passed to
                ``BalancedIrreps`` together with ``lmax_h`` for the steerable hidden irreps.
            N: controls the depth; the stack is one entry layer, ``N - 2`` middle layers
                and one exit layer.
            dim: stored on the instance; not otherwise used in this constructor.
            lmax_h: order passed to ``BalancedIrreps`` for the hidden irreps.
            lmax_pos: maximum order of the spherical harmonics for relative positions.
            update_pos: forwarded to every SEGNN layer.
            recurrent: forwarded to the middle layers; entry and exit layers use False.
            regress_forces, use_pbc, otf_graph: stored on the instance only.
        """
        super(SEGNNModel, self).__init__()

        # Hyper-parameters are kept on the instance.
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.N = N
        self.regress_forces = regress_forces
        self.otf_graph = otf_graph
        self.use_pbc = use_pbc
        self.update_pos = update_pos
        self.recurrent = recurrent
        self.dim = dim
        self.lmax_h = lmax_h
        self.lmax_pos = lmax_pos

        # The representations used in the model.
        self.irreps_in = Irreps(f"{self.in_features}x0e")
        self.irreps_hidden = BalancedIrreps(self.lmax_h, self.hidden_features)
        self.irreps_hidden_scalar = Irreps(f"{self.hidden_features}x0e")
        self.irreps_out = Irreps(f"{self.out_features}x0e")
        self.irreps_rel_pos = Irreps.spherical_harmonics(self.lmax_pos)

        # Short local aliases for the irreps that are passed around repeatedly.
        hidden = self.irreps_hidden
        scalar = self.irreps_hidden_scalar
        rel_pos = self.irreps_rel_pos

        # Embedding: point-wise, scalar irreps only.
        self.embedding = nn.Sequential(
            O3LinearSwishGate(self.irreps_in, scalar),
            O3Linear(scalar, scalar),
        )

        # Entry layer: scalar irreps -> irreps up to order lmax_h.
        stack = [SEGNN(scalar, hidden, rel_pos, hidden,
                       update_pos=self.update_pos, recurrent=False)]
        # Middle layers act on the hidden irreps throughout.
        stack.extend(
            SEGNN(hidden, hidden, rel_pos, hidden,
                  update_pos=self.update_pos, recurrent=self.recurrent)
            for _ in range(self.N - 2)
        )
        # Exit layer: back to scalar irreps.
        stack.append(SEGNN(hidden, scalar, rel_pos, scalar,
                           update_pos=self.update_pos, recurrent=False))
        self.layers = nn.ModuleList(stack)

        # Output head (point-wise, scalar irreps), split around the pooling step.
        self.head_pre_pool = nn.Sequential(
            O3LinearSwishGate(scalar, scalar),
            O3Linear(scalar, scalar),
        )
        self.head_post_pool = nn.Sequential(
            O3LinearSwishGate(scalar, scalar),
            O3Linear(scalar, self.irreps_out),
        )

        # Atom embedding table: one 9-dimensional row per index 0..100.
        atom_map = torch.zeros(101, 9)
        for index in range(101):
            atom_map[index] = torch.tensor(CONTINUOUS_EMBEDDINGS[index])

        # Row 0 is overwritten with NaN so that it is left out of the per-column
        # min/max statistics; it remains NaN after the rescaling below.
        atom_map[0] = float("nan")
        valid_rows = atom_map[~torch.isnan(atom_map[:, 0])]
        column_min, _ = valid_rows.min(dim=0)
        column_max, _ = valid_rows.max(dim=0)
        column_gap = column_max - column_min

        # Min-max rescale every column using the statistics of the valid rows.
        atom_map = (atom_map - column_min.unsqueeze(0)) / column_gap.unsqueeze(0)
        self.atom_map = nn.Parameter(atom_map, requires_grad=False)