def test_shape_cfconv(
    batchsize,
    n_atom_basis,
    n_filters,
    filter_network,
    atomic_env,
    distances,
    neighbors,
    neighbor_mask,
    n_atoms,
):
    """Check the output shape of a ``CFConv`` layer.

    ``CFConv`` is built with ``n_atom_basis`` as its first and third
    positional arguments. Its output for ``(atomic_env, distances,
    neighbors, neighbor_mask)`` must have shape
    ``[batchsize, n_atoms, n_atom_basis]``.
    All arguments are supplied by pytest fixtures.
    """
    conv = CFConv(n_atom_basis, n_filters, n_atom_basis, filter_network)
    expected_shape = [batchsize, n_atoms, n_atom_basis]
    assert_equal_shape(
        conv,
        [atomic_env, distances, neighbors, neighbor_mask],
        expected_shape,
    )
def __init__(
    self,
    n_atom_basis,
    n_spatial_basis,
    n_filters,
    cutoff,
    cutoff_network=HardCutoff,
    normalize_filter=False,
):
    """Build a SchNet interaction block.

    Args:
        n_atom_basis: width passed as the first and third positional
            arguments of ``CFConv`` and as both sizes of the final ``Dense``
            layer (presumably the atom-wise feature size; confirm against
            ``CFConv``).
        n_spatial_basis: input size of the first layer of the
            filter-generating network.
        n_filters: output size of both filter-network layers.
        cutoff: value passed to ``cutoff_network`` when it is instantiated.
        cutoff_network: class/factory called as ``cutoff_network(cutoff)``
            to build the cutoff layer; defaults to ``HardCutoff``.
        normalize_filter: forwarded unchanged to ``CFConv``.
    """
    super(SchNetInteraction, self).__init__()

    # Filter-generating network: a Dense layer with shifted-softplus
    # activation followed by a second Dense layer of the same width.
    first_filter_layer = Dense(n_spatial_basis, n_filters, activation=shifted_softplus)
    second_filter_layer = Dense(n_filters, n_filters)
    self.filter_network = nn.Sequential(first_filter_layer, second_filter_layer)

    # Cutoff layer; the same instance is handed to the convolution below.
    self.cutoff_network = cutoff_network(cutoff)

    # Continuous-filter convolution that uses the filter network above.
    conv_options = dict(
        cutoff_network=self.cutoff_network,
        activation=shifted_softplus,
        normalize_filter=normalize_filter,
    )
    self.cfconv = CFConv(
        n_atom_basis, n_filters, n_atom_basis, self.filter_network, **conv_options
    )

    # Output Dense layer, n_atom_basis -> n_atom_basis, with bias and no
    # activation.
    self.dense = Dense(n_atom_basis, n_atom_basis, bias=True, activation=None)
def __init__(
    self,
    n_atom_basis,
    n_spatial_basis,
    n_filters,
    cutoff,
    cutoff_network=CosineCutoff,
    normalize_filter=False,
    n_heads_weights=0,
    n_heads_conv=0,
    device=torch.device("cpu"),
    hyperparams=None,
    dropout=0,
    exp=False,
):
    """Build a SchNet interaction block with optional attention heads.

    Args:
        n_atom_basis: width passed as the first and third positional
            arguments of ``CFConv`` and as both sizes of the final ``Dense``
            layer. When ``n_heads_weights > 0`` it is also added to the
            filter network's input size (per the author's notes, the width
            of the attention-value projection).
        n_spatial_basis: base input size of the filter network (per the
            author's notes, the number of Gaussian expansions).
        n_filters: output size of both filter-network layers.
        cutoff: value passed to ``cutoff_network`` when it is instantiated.
        cutoff_network: class/factory called as ``cutoff_network(cutoff)``;
            defaults to ``CosineCutoff``.
        normalize_filter: forwarded to ``CFConv``.
        n_heads_weights: stored on the instance and forwarded to ``CFConv``.
            Any value > 0 widens the filter-network input by
            ``n_atom_basis``.
        n_heads_conv: stored on the instance and forwarded to ``CFConv``.
        device: stored on the instance and forwarded to ``CFConv``.
        hyperparams: forwarded to ``CFConv``. ``None`` (the default) means
            a fresh ``[0, 0]`` list for this instance.
        dropout: forwarded to ``CFConv``.
        exp: forwarded to ``CFConv``.
    """
    super(SchNetInteraction, self).__init__()

    # Build the default per call: a literal ``[0, 0]`` default would be a
    # single list shared by every instance (and by CFConv, if it keeps it).
    if hyperparams is None:
        hyperparams = [0, 0]

    self.n_heads_weights = n_heads_weights
    self.n_heads_conv = n_heads_conv
    self.device = device

    # Enabling attention weights (any number of heads) adds a single
    # n_atom_basis-wide block to the filter input.
    # TODO: this could become n_attention_heads * attention_dim later.
    filter_in = n_spatial_basis + (n_atom_basis if n_heads_weights > 0 else 0)

    # Filter-generating network used in the interaction block.
    self.filter_network = nn.Sequential(
        Dense(filter_in, n_filters, activation=shifted_softplus),
        Dense(n_filters, n_filters),
    )

    # Cutoff layer; the same instance is handed to the convolution below.
    self.cutoff_network = cutoff_network(cutoff)

    # Interaction block. ``exp`` is forwarded; it was previously
    # hard-coded to False, so the constructor argument had no effect.
    self.cfconv = CFConv(
        n_atom_basis,
        n_filters,
        n_atom_basis,
        self.filter_network,
        cutoff_network=self.cutoff_network,
        activation=shifted_softplus,
        normalize_filter=normalize_filter,
        n_heads_weights=self.n_heads_weights,
        n_heads_conv=self.n_heads_conv,
        device=self.device,
        hyperparams=hyperparams,
        dropout=dropout,
        exp=exp,
    )

    # Output Dense layer, n_atom_basis -> n_atom_basis, with bias and no
    # activation.
    self.dense = Dense(n_atom_basis, n_atom_basis, bias=True, activation=None)