# Layers used by the message-passing loop. Each one is constructed once and
# the same instance is applied on every iteration below.
atom_rnn_layer = GRUStep(atom_features)
message_layer = MessageLayer(reducer='sum')
message_steps = 3

# Perform the message passing
for _ in range(message_steps):

    # Get the message updates to each atom
    message = message_layer([atom_state, bond_matrix, connectivity])

    # Update memory and atom states
    atom_state = atom_rnn_layer([message, atom_state])

# Per-atom readout: a 64-unit ReLU dense layer applied to the final atom state.
atom_readout = Dense(64, activation='relu')
atom_fingerprint = atom_readout(atom_state)

# Combine the per-atom fingerprints with a sum reducer, passing the squeezed
# node -> graph indices alongside them, then batch-normalise the result.
graph_readout = GraphOutput(reducer='sum')
mol_fingerprint = graph_readout([snode_graph_indices, atom_fingerprint])
mol_fingerprint = BatchNormalization()(mol_fingerprint)

# Single-unit linear output head.
out = Dense(1)(mol_fingerprint)

model_inputs = [node_graph_indices, atom_types, bond_types, connectivity]
model = GraphModel(model_inputs, [out])

optimizer = keras.optimizers.Adam(lr=0.001)
model.compile(optimizer=optimizer, loss='mse')
model.summary()

# Warnings raised while fitting are silenced for the duration of training.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    hist = model.fit_generator(
        train_generator,
        validation_data=test_generator,
        epochs=50,
        verbose=2,
    )
# One more message block over the current atom and bond states.
atom_state, bond_state = message_block(atom_state, bond_state, connectivity)

# Readout head: a softplus dense layer at half the atom feature width,
# followed by a single-unit linear layer.
hidden_readout = Dense(atom_features // 2, activation='softplus')
atom_state = hidden_readout(atom_state)
scalar_readout = Dense(1)
atom_state = scalar_readout(atom_state)

# Add the atomwise_energy term to the per-atom prediction.
atom_state = Add()([atom_state, atomwise_energy])

# Reduce the per-atom values to one value per molecule using a mean reducer.
output = ReduceAtomToMol(reducer='mean')([atom_state, snode_graph_indices])

model_inputs = [node_graph_indices, atom_types, distance_rbf, connectivity]
model = GraphModel(model_inputs, [output])

# Training hyperparameters.
lr = 1E-4
epochs = 500

optimizer = keras.optimizers.Adam(lr=lr, decay=1E-5)
model.compile(optimizer=optimizer, loss='mae')
model.summary()

# Create the output directory on first use.
if not os.path.exists(model_name):
    os.makedirs(model_name)

# Pickle the preprocessor into the same directory as the model artefacts.
with open('{}/schnet_preprocessor.p'.format(model_name), 'wb') as preprocessor_file:
    pickle.dump(preprocessor, preprocessor_file)

# Callbacks: a best-only checkpoint (period=10) and a CSV log of the run.
filepath = model_name + "/best_model.hdf5"
checkpoint = ModelCheckpoint(
    filepath,
    save_best_only=True,
    period=10,
    verbose=1,
)
csv_logger = CSVLogger(model_name + '/log.csv')

hist = model.fit_generator(
    train_sequence,
    validation_data=valid_sequence,
    epochs=epochs,
    verbose=1,
    callbacks=[checkpoint, csv_logger],
)
def test_save_and_load_model(get_2d_sequence, tmpdir):
    """Check that a saved and re-loaded model evaluates to the same loss.

    A small message-passing model is built, trained for one epoch on the
    fixture sequence, saved to an HDF5 file, loaded back with the custom
    layer registry, and evaluated again; both losses must agree.

    Args:
        get_2d_sequence: fixture yielding a ``(preprocessor, sequence)`` pair.
        tmpdir: pytest temporary-directory fixture; the model file is written
            here so pytest owns its cleanup.
    """
    preprocessor, sequence = get_2d_sequence

    node_graph_indices = Input(
        shape=(1, ), name='node_graph_indices', dtype='int32')
    atom_types = Input(shape=(1, ), name='atom', dtype='int32')
    bond_types = Input(shape=(1, ), name='bond', dtype='int32')
    connectivity = Input(shape=(2, ), name='connectivity', dtype='int32')

    # One Squeeze layer instance is applied to each of the shape-(1,) inputs.
    squeeze = Squeeze()

    snode_graph_indices = squeeze(node_graph_indices)
    satom_types = squeeze(atom_types)
    sbond_types = squeeze(bond_types)

    atom_features = 5

    atom_state = Embedding(
        preprocessor.atom_classes,
        atom_features, name='atom_embedding')(satom_types)

    bond_matrix = Embedding2D(
        preprocessor.bond_classes,
        atom_features, name='bond_embedding')(sbond_types)

    atom_rnn_layer = GRUStep(atom_features)
    message_layer = MessageLayer(reducer='sum', dropout=0.1)

    # Perform the message passing
    for _ in range(2):

        # Get the message updates to each atom
        message = message_layer([atom_state, bond_matrix, connectivity])

        # Update memory and atom states
        atom_state = atom_rnn_layer([message, atom_state])

    atom_fingerprint = Dense(64, activation='sigmoid')(atom_state)
    mol_fingerprint = ReduceAtomToMol(reducer='sum')(
        [atom_fingerprint, snode_graph_indices])

    out = Dense(1)(mol_fingerprint)
    model = GraphModel(
        [node_graph_indices, atom_types, bond_types, connectivity], [out])

    # Warnings raised during compile/fit are silenced for this test.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        model.compile(optimizer=keras.optimizers.Adam(lr=1E-4), loss='mse')
        model.fit_generator(sequence, epochs=1)

    loss = model.evaluate_generator(sequence)

    # Save inside pytest's tmpdir rather than via tempfile.mkstemp():
    # mkstemp returns an open file descriptor that was being discarded
    # (leaked), and the file it created was never removed.
    fname = str(tmpdir.join('model.h5'))
    model.save(fname)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model = load_model(fname, custom_objects=custom_layers)

    loss2 = model.evaluate_generator(sequence)

    assert_allclose(loss, loss2)