from e3nn.o3 import Irreps, FullyConnectedTensorProduct


def WeightBalancedIrreps(irreps_in1_scalar, irreps_in2, sh=True):
    """
    Determines an irreps_in1 type of order irreps_in2.lmax that, when used in a tensor product
    irreps_in1 x irreps_in2 -> irreps_in1, has at least the same number of weights as a standard
    linear layer, i.e. the tensor product irreps_in1_scalar x "1x0e" -> irreps_in1_scalar.
    """
    n = 1
    lmax = irreps_in2.lmax
    # BalancedIrreps is a repo-local helper, used when sh=False.
    irreps_in1 = (Irreps.spherical_harmonics(lmax) * n).sort().irreps.simplify() if sh else BalancedIrreps(lmax, n)
    weight_numel1 = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_in1).weight_numel
    weight_numel_scalar = FullyConnectedTensorProduct(irreps_in1_scalar, Irreps("1x0e"), irreps_in1_scalar).weight_numel
    # Increase the multiplicity n until the steerable tensor product has at least
    # as many weights as the scalar baseline.
    while weight_numel1 < weight_numel_scalar:  # TODO: somewhat suboptimal implementation...
        n += 1
        irreps_in1 = (Irreps.spherical_harmonics(lmax) * n).sort().irreps.simplify() if sh else BalancedIrreps(lmax, n)
        weight_numel1 = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_in1).weight_numel
    print('Determined irrep type:', irreps_in1)
    return Irreps(irreps_in1)

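# Usage sketch (illustrative, not part of the original code): find a steerable
# hidden type with roughly the capacity of a 256-wide linear layer, mirroring
# the call made in SEGNNModel below.
hidden_scalar = Irreps("256x0e")             # scalar hidden features
attr_irreps = Irreps.spherical_harmonics(2)  # steerable attributes up to l = 2
hidden_irreps = WeightBalancedIrreps(hidden_scalar, attr_irreps, sh=True)
print(hidden_irreps.dim)                     # total dimension of the balanced type
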
from e3nn.o3 import Irreps, FromS2Grid, ToS2Grid
from e3nn.util.test import assert_equivariant


def test_equivariance(lmax, res_b, res_a):
    m = FromS2Grid((res_b, res_a), lmax)
    k = ToS2Grid(lmax, (res_b, res_a))

    def f(x):
        y = k(x)      # spherical harmonic coefficients -> signal on the grid
        y = y.exp()   # point-wise non-linearity on the grid
        return m(y)   # signal on the grid -> spherical harmonic coefficients

    f.irreps_in = f.irreps_out = Irreps.spherical_harmonics(lmax)
    assert_equivariant(f)

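# Example invocation (illustrative values, not from the original source;
# assumes the grid is fine enough for the chosen lmax, roughly
# res_b >= 2 * (lmax + 1) and res_a >= 2 * lmax + 1):
test_equivariance(lmax=3, res_b=20, res_a=39)
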
def __init__(self,
             num_atoms,      # not used
             bond_feat_dim,  # not used
             num_targets,    # not used
             in_features=9,
             out_features=1,
             hidden_features=256,
             N=7,
             dim=3,
             lmax_h=2,
             lmax_pos=2,
             update_pos=False,
             recurrent=True,
             regress_forces=False,
             use_pbc=True,
             otf_graph=False):
    super(SEGNNModel, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.hidden_features = hidden_features
    self.N = N
    self.regress_forces = regress_forces
    self.otf_graph = otf_graph
    self.use_pbc = use_pbc
    self.update_pos = update_pos
    self.recurrent = recurrent
    self.dim = dim
    self.lmax_h = lmax_h
    self.lmax_pos = lmax_pos

    # Irreps for the node features
    node_in_irreps_scalar = Irreps("{0}x0e".format(self.in_features))          # type of the input
    # node_hidden_irreps = BalancedIrreps(self.lmax_h, self.hidden_features)   # type of the hidden reps
    node_hidden_irreps_scalar = Irreps("{0}x0e".format(self.hidden_features))  # for the output layers
    node_out_irreps_scalar = Irreps("{0}x0e".format(self.out_features))        # type of the output

    # Irreps for the edge and node attributes
    attr_irreps = Irreps.spherical_harmonics(self.lmax_pos)
    self.attr_irreps = attr_irreps

    node_hidden_irreps = WeightBalancedIrreps(node_hidden_irreps_scalar, attr_irreps, False)  # True: copies of sh

    # Network for computing the node attributes
    self.node_attribute_net = NodeAttributeNetwork()

    # The embedding layer (acts point-wise, no orientation information so only use trivial/scalar irreps)
    self.embedding_layer_1 = O3TensorProductSwishGate(node_in_irreps_scalar,  # in
                                                      node_hidden_irreps,     # out
                                                      attr_irreps)            # steerable attribute
    self.embedding_layer_2 = O3TensorProductSwishGate(node_hidden_irreps,     # in
                                                      node_hidden_irreps,     # out
                                                      attr_irreps)            # steerable attribute
    self.embedding_layer_3 = O3TensorProduct(node_hidden_irreps,              # in
                                             node_hidden_irreps,              # out
                                             attr_irreps)                     # steerable attribute

    # The main layers
    self.layers = []
    for i in range(self.N):
        self.layers.append(SEGNN(node_hidden_irreps,  # in
                                 node_hidden_irreps,  # hidden
                                 node_hidden_irreps,  # out
                                 attr_irreps,         # steerable attribute
                                 update_pos=self.update_pos,
                                 recurrent=self.recurrent))
    self.layers = nn.ModuleList(self.layers)

    # The output network (again a point-wise operation via scalar irreps)
    self.head_pre_pool_layer_1 = O3TensorProductSwishGate(node_hidden_irreps,         # in
                                                          node_hidden_irreps_scalar,  # out
                                                          attr_irreps)                # steerable attribute
    self.head_pre_pool_layer_2 = O3TensorProduct(node_hidden_irreps_scalar,   # in
                                                 node_hidden_irreps_scalar)   # out
    self.head_post_pool_layer_1 = O3TensorProductSwishGate(node_hidden_irreps_scalar,  # in
                                                           node_hidden_irreps_scalar)  # out
    self.head_post_pool_layer_2 = O3TensorProduct(node_hidden_irreps_scalar,  # in
                                                  node_out_irreps_scalar)     # out

    # Read the atom map (continuous per-element embeddings)
    atom_map = torch.zeros(101, 9)
    for i in range(101):
        atom_map[i] = torch.tensor(CONTINUOUS_EMBEDDINGS[i])

    # Normalize along each dimension
    atom_map[0] = np.nan
    atom_map_notnan = atom_map[atom_map[:, 0] == atom_map[:, 0]]  # NaN != NaN drops the placeholder row
    atom_map_min = torch.min(atom_map_notnan, dim=0)[0]
    atom_map_max = torch.max(atom_map_notnan, dim=0)[0]
    atom_map_gap = atom_map_max - atom_map_min
    # Squash to [0, 1]
    atom_map = (atom_map - atom_map_min.view(1, -1)) / atom_map_gap.view(1, -1)
    self.atom_map = torch.nn.Parameter(atom_map, requires_grad=False)

    # Read and rescale the atom radii
    atom_radii = torch.zeros(101)
    for i in range(101):
        atom_radii[i] = ATOMIC_RADII[i]
    atom_radii = atom_radii / 100
    self.atom_radii = nn.Parameter(atom_radii, requires_grad=False)

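# Standalone sketch of the normalization trick above (illustrative values, not
# part of the original source; assumes torch is imported as in the surrounding
# code): row 0 is a NaN placeholder, and the NaN != NaN property is used to
# drop it before computing per-column min/max.
t = torch.tensor([[float('nan'), float('nan')],
                  [1.0, 10.0],
                  [3.0, 30.0]])
valid = t[t[:, 0] == t[:, 0]]  # keeps only rows whose first entry is not NaN
lo = torch.min(valid, dim=0)[0]
hi = torch.max(valid, dim=0)[0]
print((t - lo) / (hi - lo))    # NaN row stays NaN; the rest squash to [0, 1]
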
def __init__(self,
             num_atoms,      # not used
             bond_feat_dim,  # not used
             num_targets,    # not used
             in_features=9,
             out_features=1,
             hidden_features=256,
             N=7,
             dim=3,
             lmax_h=2,
             lmax_pos=2,
             update_pos=False,
             recurrent=True,
             regress_forces=False,
             use_pbc=True,
             otf_graph=False):
    super(SEGNNModel, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.hidden_features = hidden_features
    self.N = N
    self.regress_forces = regress_forces
    self.otf_graph = otf_graph
    self.use_pbc = use_pbc
    self.update_pos = update_pos
    self.recurrent = recurrent
    self.dim = dim
    self.lmax_h = lmax_h
    self.lmax_pos = lmax_pos

    # The representations used in the model
    self.irreps_in = Irreps("{0}x0e".format(self.in_features))
    self.irreps_hidden = BalancedIrreps(self.lmax_h, self.hidden_features)
    self.irreps_hidden_scalar = Irreps("{0}x0e".format(self.hidden_features))
    self.irreps_out = Irreps("{0}x0e".format(self.out_features))
    self.irreps_rel_pos = Irreps.spherical_harmonics(self.lmax_pos)

    # The embedding layer (acts point-wise, no orientation information so only use trivial/scalar irreps)
    self.embedding = nn.Sequential(O3LinearSwishGate(self.irreps_in, self.irreps_hidden_scalar),
                                   O3Linear(self.irreps_hidden_scalar, self.irreps_hidden_scalar))

    # The intermediate layers
    self.layers = []
    # The first layer changes from scalar irreps to irreps of some max order (lmax_h)
    self.layers.append(SEGNN(self.irreps_hidden_scalar,
                             self.irreps_hidden,
                             self.irreps_rel_pos,
                             self.irreps_hidden,
                             update_pos=self.update_pos,
                             recurrent=False))
    # Subsequent layers act on the irreps of that max order (lmax_h)
    for i in range(self.N - 2):
        self.layers.append(SEGNN(self.irreps_hidden,
                                 self.irreps_hidden,
                                 self.irreps_rel_pos,
                                 self.irreps_hidden,
                                 update_pos=self.update_pos,
                                 recurrent=self.recurrent))
    # The last layer of the SEGNN block converts back to scalar irreps
    self.layers.append(SEGNN(self.irreps_hidden,
                             self.irreps_hidden_scalar,
                             self.irreps_rel_pos,
                             self.irreps_hidden_scalar,
                             update_pos=self.update_pos,
                             recurrent=False))
    # To ModuleList
    self.layers = nn.ModuleList(self.layers)

    # The output network (again a point-wise operation via scalar irreps)
    self.head_pre_pool = nn.Sequential(O3LinearSwishGate(self.irreps_hidden_scalar, self.irreps_hidden_scalar),
                                       O3Linear(self.irreps_hidden_scalar, self.irreps_hidden_scalar))
    self.head_post_pool = nn.Sequential(O3LinearSwishGate(self.irreps_hidden_scalar, self.irreps_hidden_scalar),
                                        O3Linear(self.irreps_hidden_scalar, self.irreps_out))

    # Read the atom map (continuous per-element embeddings)
    atom_map = torch.zeros(101, 9)
    for i in range(101):
        atom_map[i] = torch.tensor(CONTINUOUS_EMBEDDINGS[i])

    # Normalize along each dimension
    atom_map[0] = np.nan
    atom_map_notnan = atom_map[atom_map[:, 0] == atom_map[:, 0]]  # NaN != NaN drops the placeholder row
    atom_map_min = torch.min(atom_map_notnan, dim=0)[0]
    atom_map_max = torch.max(atom_map_notnan, dim=0)[0]
    atom_map_gap = atom_map_max - atom_map_min
    # Squash to [0, 1]
    atom_map = (atom_map - atom_map_min.view(1, -1)) / atom_map_gap.view(1, -1)
    self.atom_map = torch.nn.Parameter(atom_map, requires_grad=False)

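# Hypothetical construction sketch (not in the original source): the first
# three arguments are unused by the model itself (an OCP-style trainer passes
# them), so None placeholders suffice; the rest mirror the defaults above.
model = SEGNNModel(num_atoms=None, bond_feat_dim=None, num_targets=None,
                   in_features=9, out_features=1, hidden_features=256,
                   N=7, lmax_h=2, lmax_pos=2)
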