def test_Embedding2D():
    """Check Embedding2D output shape and its per-class lookup behaviour.

    An ``Embedding2D(3, 5)`` layer is fed squeezed integer bond classes and
    must report a ``(None, 5, 5)`` output.  Rows that share a class index
    must come back identical, and the first and last rows (classes 1 and 0)
    must not be entirely equal.
    """
    bond_input = layers.Input(name='bond', shape=(1, ), dtype='int32')
    flat_bond = Squeeze()(bond_input)

    embed_layer = Embedding2D(3, 5)
    embedded = embed_layer(flat_bond)
    assert embedded._keras_shape == (None, 5, 5)

    model = GraphModel([bond_input], embedded)
    bond_classes = np.array([1, 1, 2, 2, 0])
    matrices = model.predict_on_batch([bond_classes])

    # Rows 0/1 both use class 1 and rows 2/3 both use class 2.
    for left, right in ((0, 1), (2, 3)):
        assert_allclose(matrices[left], matrices[right])

    # Class 1 (first row) and class 0 (last row) must differ somewhere.
    assert (matrices[0] != matrices[-1]).any()
def test_set2set():
    """Check that Set2Set turns 5-feature atom rows into 10-feature graph rows.

    Five atom rows are assigned to two graphs (indices 0 and 1); the layer
    must declare a ``(None, 10)`` output and the prediction must have shape
    ``(2, 10)``.
    """
    atom_input = layers.Input(name='atom', shape=(5, ), dtype='float32')
    graph_index_input = layers.Input(
        name='node_graph_indices', shape=(1, ), dtype='int32')
    flat_graph_index = Squeeze()(graph_index_input)

    pooled = Set2Set()([atom_input, flat_graph_index])
    assert pooled._keras_shape == (None, 10)

    model = GraphModel([atom_input, graph_index_input], pooled)

    atom_features = np.random.rand(5, 5)
    # First three atoms belong to graph 0, the remaining two to graph 1.
    graph_membership = np.array([0, 0, 0, 1, 1])

    prediction = model.predict_on_batch([atom_features, graph_membership])
    assert prediction.shape == (2, 10)
def test_ReduceAtomToMol():
    """Check that ReduceAtomToMol sums atom rows per graph.

    Five random atom rows are split between two graphs (three rows in graph
    0, two in graph 1).  The layer must declare a ``(None, 5)`` output and
    each output row must equal the column-wise sum of that graph's atom rows.
    """
    atom = layers.Input(name='atom', shape=(5, ), dtype='float32')
    node_graph_indices = layers.Input(
        name='node_graph_indices', shape=(1, ), dtype='int32')
    snode = Squeeze()(node_graph_indices)

    reduce_layer = ReduceAtomToMol()
    o = reduce_layer([atom, snode])
    assert o._keras_shape == (None, 5)

    model = GraphModel([atom, node_graph_indices], o)

    x1 = np.random.rand(5, 5)
    x2 = np.array([0, 0, 0, 1, 1])

    out = model.predict_on_batch([x1, x2])

    # The model input is declared float32 while the reference sums are
    # computed on the float64 array.  The default rtol (1e-7) is tighter
    # than float32 precision after a multi-term sum, so use a tolerance
    # that float32 arithmetic can actually meet.
    assert_allclose(x1[:3].sum(0), out[0], rtol=1E-5)
    assert_allclose(x1[3:].sum(0), out[1], rtol=1E-5)
def test_GatherMolToAtomOrBond():
    """Check that GatherMolToAtomOrBond copies each graph's state to its rows.

    Two random global-state rows are gathered using five graph indices; the
    layer must declare a ``(None, 5)`` output and the prediction must match
    indexing the state array by those indices.
    """
    state_input = layers.Input(
        name='global_state', shape=(5, ), dtype='float32')
    graph_index_input = layers.Input(
        name='node_graph_indices', shape=(1, ), dtype='int32')
    flat_graph_index = Squeeze()(graph_index_input)

    gather_layer = GatherMolToAtomOrBond()
    gathered = gather_layer([state_input, flat_graph_index])
    assert gathered._keras_shape == (None, 5)

    model = GraphModel([state_input, graph_index_input], gathered)

    graph_states = np.random.rand(2, 5)
    graph_membership = np.array([0, 0, 0, 1, 1])
    prediction = model.predict_on_batch([graph_states, graph_membership])

    # Expected result: one state row per entry of graph_membership.
    expected = np.take(graph_states, graph_membership, axis=0)
    assert_allclose(prediction, expected)
# atom_contributions = pd.Series(model.coef_.flatten(), index=X.columns) # atom_contributions = atom_contributions.reindex(np.arange(preprocessor.atom_classes)).fillna(0) # Construct input sequences batch_size = 32 train_sequence = GraphSequence(train_inputs, y_train_scaled, batch_size, final_batch=False) valid_sequence = GraphSequence(valid_inputs, y_valid_scaled, batch_size, final_batch=False) # Raw (integer) graph inputs node_graph_indices = Input(shape=(1,), name='node_graph_indices', dtype='int32') atom_types = Input(shape=(1,), name='atom', dtype='int32') distance_rbf = Input(shape=(150,), name='distance_rbf', dtype='float32') connectivity = Input(shape=(2,), name='connectivity', dtype='int32') squeeze = Squeeze() snode_graph_indices = squeeze(node_graph_indices) satom_types = squeeze(atom_types) # Initialize RNN and MessageLayer instances atom_features = 64 # Initialize the atom states atom_state = Embedding( preprocessor.atom_classes, atom_features, name='atom_embedding')(satom_types) atomwise_energy = Embedding( preprocessor.atom_classes, 1, name='atomwise_energy', )(satom_types)
def test_save_and_load_model(get_2d_sequence, tmpdir):
    """Round-trip a small message-passing model through ``save``/``load_model``.

    A two-step message-passing network is built from the preprocessor's atom
    and bond class counts, trained for one epoch, evaluated, saved to an HDF5
    file, reloaded with ``custom_layers``, and evaluated again.  The loss
    before and after the round trip must match.

    Args:
        get_2d_sequence: fixture yielding a ``(preprocessor, sequence)`` pair.
        tmpdir: per-test temporary directory that receives the saved model
            (NOTE(review): assumed to be pytest's built-in ``tmpdir``
            fixture -- confirm no conftest overrides it).
    """
    preprocessor, sequence = get_2d_sequence

    node_graph_indices = Input(
        shape=(1, ), name='node_graph_indices', dtype='int32')
    atom_types = Input(shape=(1, ), name='atom', dtype='int32')
    bond_types = Input(shape=(1, ), name='bond', dtype='int32')
    connectivity = Input(shape=(2, ), name='connectivity', dtype='int32')

    squeeze = Squeeze()

    snode_graph_indices = squeeze(node_graph_indices)
    satom_types = squeeze(atom_types)
    sbond_types = squeeze(bond_types)

    atom_features = 5

    atom_state = Embedding(
        preprocessor.atom_classes, atom_features,
        name='atom_embedding')(satom_types)

    bond_matrix = Embedding2D(
        preprocessor.bond_classes, atom_features,
        name='bond_embedding')(sbond_types)

    atom_rnn_layer = GRUStep(atom_features)
    message_layer = MessageLayer(reducer='sum', dropout=0.1)

    # Perform the message passing
    for _ in range(2):

        # Get the message updates to each atom
        message = message_layer([atom_state, bond_matrix, connectivity])

        # Update memory and atom states
        atom_state = atom_rnn_layer([message, atom_state])

    atom_fingerprint = Dense(64, activation='sigmoid')(atom_state)
    mol_fingerprint = ReduceAtomToMol(reducer='sum')(
        [atom_fingerprint, snode_graph_indices])

    out = Dense(1)(mol_fingerprint)
    model = GraphModel(
        [node_graph_indices, atom_types, bond_types, connectivity], [out])

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        model.compile(optimizer=keras.optimizers.Adam(lr=1E-4), loss='mse')
        model.fit_generator(sequence, epochs=1)

    loss = model.evaluate_generator(sequence)

    # Save inside the per-test temporary directory.  The previous
    # tempfile.mkstemp() call returned an open file descriptor that was never
    # closed and left the .h5 file behind after the test run.
    fname = str(tmpdir.join('model.h5'))
    model.save(fname)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        model = load_model(fname, custom_objects=custom_layers)
        loss2 = model.evaluate_generator(sequence)

    assert_allclose(loss, loss2)