def test_ReduceAtomToMol():
    """Check that ReduceAtomToMol (default reducer) sums atom rows per molecule.

    Five atoms with 5 features each are assigned to two molecules through
    ``node_graph_indices`` (atoms 0-2 -> molecule 0, atoms 3-4 -> molecule 1).
    Output row ``i`` must equal the column-wise sum of the atom rows whose
    index is ``i``, and the symbolic output shape must be ``(None, 5)``.
    """
    atom = layers.Input(name='atom', shape=(5, ), dtype='float32')
    node_graph_indices = layers.Input(
        name='node_graph_indices', shape=(1, ), dtype='int32')

    # Squeeze presumably drops the length-1 feature axis so the indices
    # become a flat vector -- confirm against the Squeeze layer definition.
    snode = Squeeze()(node_graph_indices)

    reduce_layer = ReduceAtomToMol()
    o = reduce_layer([atom, snode])
    assert o._keras_shape == (None, 5)

    model = GraphModel([atom, node_graph_indices], o)

    # Build the inputs in the dtypes the Input layers declare so the expected
    # values below are computed from exactly the numbers the model sees.
    x1 = np.random.rand(5, 5).astype('float32')
    x2 = np.array([0, 0, 0, 1, 1], dtype='int32')
    out = model.predict_on_batch([x1, x2])

    # The model computes in float32; numpy's default rtol (1e-7) is tighter
    # than float32 round-off and made this comparison randomly fail.
    assert_allclose(x1[:3].sum(0), out[0], rtol=1e-5)
    assert_allclose(x1[3:].sum(0), out[1], rtol=1e-5)
# state transition function messages = Dense(atom_features, activation='softplus')(messages) messages = Dense(atom_features)(messages) atom_state = Add()([atom_state, messages]) return atom_state, bond_state for _ in range(3): atom_state, bond_state = message_block(atom_state, bond_state, connectivity) atom_state = Dense(atom_features//2, activation='softplus')(atom_state) atom_state = Dense(1)(atom_state) atom_state = Add()([atom_state, atomwise_energy]) output = ReduceAtomToMol(reducer='mean')([atom_state, snode_graph_indices]) model = GraphModel([ node_graph_indices, atom_types, distance_rbf, connectivity], [output]) lr = 1E-4 epochs = 500 model.compile(optimizer=keras.optimizers.Adam(lr=lr, decay=1E-5), loss='mae') model.summary() if not os.path.exists(model_name): os.makedirs(model_name) with open('{}/schnet_preprocessor.p'.format(model_name), 'wb') as f: pickle.dump(preprocessor, f)
# Message-passing encoder plus a dense regression head for the graph model.
# This chunk relies on names defined earlier in the script (not visible here):
# atom_features, atom_state, bond_matrix, connectivity, snode_graph_indices,
# node_graph_indices, atom_types, bond_types and num_output.

# One GRUStep and one MessageLayer instance are created once and applied on
# every step below, so the same layer weights are reused across all
# message-passing iterations.
atom_rnn_layer = GRUStep(atom_features)
message_layer = MessageLayer(reducer='sum')
message_steps = 3

# Perform the message passing
for _ in range(message_steps):

    # Get the message updates to each atom
    message = message_layer([atom_state, bond_matrix, connectivity])

    # Update memory and atom states
    atom_state = atom_rnn_layer([message, atom_state])

# NOTE(review): batch normalization of the atom states is disabled here;
# either re-enable it deliberately or delete this commented-out line.
# atom_state = BatchNormalization(momentum=0.9)(atom_state)

# Project every atom state to a 1024-wide sigmoid fingerprint, then reduce
# (sum) the atom fingerprints that share a graph index in
# snode_graph_indices into one row per molecule.
atom_fingerprint = Dense(1024, activation='sigmoid')(atom_state)
mol_out = ReduceAtomToMol(reducer='sum')(
    [atom_fingerprint, snode_graph_indices])

# Dense head: BN -> 512 relu -> BN -> 256 relu -> linear layer with
# num_output units.
X = BatchNormalization(momentum=0.9)(mol_out)
X = Dense(512, activation='relu')(X)
X = BatchNormalization(momentum=0.9)(X)
X = Dense(256, activation='relu')(X)
X = Dense(num_output)(X)

model = GraphModel([node_graph_indices, atom_types, bond_types, connectivity], [X])

# Training hyperparameters; decay is presumably passed to the optimizer
# later in the script (not visible in this chunk) -- confirm.
epochs = 500
lr = 1E-3
decay = lr / epochs
def test_save_and_load_model(get_2d_sequence, tmpdir):
    """Round-trip a small message-passing model through HDF5 save/load.

    Builds a model, trains it for one epoch, evaluates it, saves it to an
    ``.h5`` file, reloads it with ``custom_layers`` as ``custom_objects``, and
    asserts that the reloaded model yields the same evaluation loss.

    Args:
        get_2d_sequence: fixture providing ``(preprocessor, sequence)``;
            presumably defined in a conftest not visible here.
        tmpdir: pytest's per-test temporary directory; the saved model is
            written there so pytest owns the file's cleanup.
    """
    preprocessor, sequence = get_2d_sequence

    node_graph_indices = Input(
        shape=(1, ), name='node_graph_indices', dtype='int32')
    atom_types = Input(shape=(1, ), name='atom', dtype='int32')
    bond_types = Input(shape=(1, ), name='bond', dtype='int32')
    connectivity = Input(shape=(2, ), name='connectivity', dtype='int32')

    # A single Squeeze instance is applied to all three index inputs.
    squeeze = Squeeze()
    snode_graph_indices = squeeze(node_graph_indices)
    satom_types = squeeze(atom_types)
    sbond_types = squeeze(bond_types)

    atom_features = 5

    atom_state = Embedding(
        preprocessor.atom_classes,
        atom_features, name='atom_embedding')(satom_types)

    bond_matrix = Embedding2D(
        preprocessor.bond_classes,
        atom_features, name='bond_embedding')(sbond_types)

    # The same layer instances are applied on every step, so their weights
    # are shared across message-passing iterations.
    atom_rnn_layer = GRUStep(atom_features)
    message_layer = MessageLayer(reducer='sum', dropout=0.1)

    # Perform the message passing
    for _ in range(2):

        # Get the message updates to each atom
        message = message_layer([atom_state, bond_matrix, connectivity])

        # Update memory and atom states
        atom_state = atom_rnn_layer([message, atom_state])

    atom_fingerprint = Dense(64, activation='sigmoid')(atom_state)
    mol_fingerprint = ReduceAtomToMol(reducer='sum')(
        [atom_fingerprint, snode_graph_indices])

    out = Dense(1)(mol_fingerprint)
    model = GraphModel(
        [node_graph_indices, atom_types, bond_types, connectivity], [out])

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model.compile(optimizer=keras.optimizers.Adam(lr=1E-4), loss='mse')
        model.fit_generator(sequence, epochs=1)
        loss = model.evaluate_generator(sequence)

    # Save into pytest's tmpdir. The previous tempfile.mkstemp call returned
    # an open OS-level file descriptor that was never closed, and the file it
    # created was never deleted.
    fname = str(tmpdir.join('model.h5'))
    model.save(fname)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        loaded_model = load_model(fname, custom_objects=custom_layers)
        loss2 = loaded_model.evaluate_generator(sequence)

    assert_allclose(loss, loss2)