Ejemplo n.º 1
0
def test_ReduceAtomToMol():
    """Check ReduceAtomToMol's default reduction against a NumPy reference.

    Five atom rows with five features each are assigned to two graphs via
    the index vector ``[0, 0, 0, 1, 1]``.  The layer output for graph 0 is
    expected to equal the column-wise sum of the first three rows, and the
    output for graph 1 the column-wise sum of the last two.
    """
    atom = layers.Input(name='atom', shape=(5, ), dtype='float32')
    node_graph_indices = layers.Input(name='node_graph_indices',
                                      shape=(1, ),
                                      dtype='int32')

    # The index input is declared with a trailing axis of length 1; it is
    # passed through Squeeze before being handed to the reduction layer.
    snode = Squeeze()(node_graph_indices)

    reduce_layer = ReduceAtomToMol()
    o = reduce_layer([atom, snode])
    assert o._keras_shape == (None, 5)

    model = GraphModel([atom, node_graph_indices], o)

    # A seeded, local generator keeps the inputs reproducible without
    # touching NumPy's global random state.
    x1 = np.random.RandomState(0).rand(5, 5)
    x2 = np.array([0, 0, 0, 1, 1])

    out = model.predict_on_batch([x1, x2])

    # The model input is declared float32 while the reference sums are
    # computed in float64, so assert_allclose's default rtol of 1e-7 is
    # tighter than single precision can guarantee.  Compare with a
    # tolerance appropriate for float32 instead.
    assert_allclose(x1[:3].sum(0), out[0], rtol=1E-5)
    assert_allclose(x1[3:].sum(0), out[1], rtol=1E-5)
Ejemplo n.º 2
0
    
    # state transition function
    messages = Dense(atom_features, activation='softplus')(messages)
    messages = Dense(atom_features)(messages)
    atom_state = Add()([atom_state, messages])
    
    return atom_state, bond_state

# Apply the message block three times, feeding each round's atom and bond
# states into the next.
for _ in range(3):
    atom_state, bond_state = message_block(atom_state, bond_state, connectivity)

# Per-atom readout: shrink to half the feature width, project to a single
# value per atom, then add the atomwise_energy term.
atom_state = Dense(atom_features//2, activation='softplus')(atom_state)
atom_state = Dense(1)(atom_state)
atom_state = Add()([atom_state, atomwise_energy])

# Reduce the per-atom values to per-graph outputs with the 'mean' reducer.
output = ReduceAtomToMol(reducer='mean')([atom_state, snode_graph_indices])

model = GraphModel([
    node_graph_indices, atom_types, distance_rbf, connectivity], [output])

lr = 1E-4
epochs = 500

model.compile(optimizer=keras.optimizers.Adam(lr=lr, decay=1E-5), loss='mae')
model.summary()

# exist_ok=True replaces the former "check, then create" sequence, which
# could raise if the directory appeared between the two calls.
os.makedirs(model_name, exist_ok=True)

# Persist the preprocessor next to the model outputs.
with open('{}/schnet_preprocessor.p'.format(model_name), 'wb') as f:
    pickle.dump(preprocessor, f)
Ejemplo n.º 3
0
# Number of message-passing rounds applied to the atom states.
message_steps = 3

# Both layers are instantiated once; the same layer objects are re-applied
# on every round of the loop below.
atom_rnn_layer = GRUStep(atom_features)
message_layer = MessageLayer(reducer='sum')

for _ in range(message_steps):
    # Gather the message update for each atom ...
    message = message_layer([atom_state, bond_matrix, connectivity])
    # ... and use it to refresh the atom states.
    atom_state = atom_rnn_layer([message, atom_state])

# Disabled: batch normalisation of the atom states before the fingerprint.
# atom_state = BatchNormalization(momentum=0.9)(atom_state)
atom_fingerprint = Dense(1024, activation='sigmoid')(atom_state)
mol_out = ReduceAtomToMol(reducer='sum')(
    [atom_fingerprint, snode_graph_indices])

# Dense head: two (BatchNormalization -> ReLU Dense) stages of width 512
# and 256, followed by a linear projection to num_output values.
X = mol_out
for hidden_units in (512, 256):
    X = BatchNormalization(momentum=0.9)(X)
    X = Dense(hidden_units, activation='relu')(X)
X = Dense(num_output)(X)

model = GraphModel(
    [node_graph_indices, atom_types, bond_types, connectivity], [X])

# Training hyper-parameters; the decay is the learning rate divided by the
# number of epochs.
lr = 1E-3
epochs = 500
decay = lr / epochs
Ejemplo n.º 4
0
def test_save_and_load_model(get_2d_sequence, tmpdir):
    """Round-trip a small message-passing model through save/load_model.

    The model is trained for one epoch on the fixture's sequence, evaluated,
    saved to an HDF5 file, reloaded with the project's custom layers, and
    evaluated again.  The two evaluation losses must match.

    Args:
        get_2d_sequence: fixture yielding a ``(preprocessor, sequence)``
            pair.
        tmpdir: pytest's per-test temporary directory; the saved model file
            is written inside it.
    """

    preprocessor, sequence = get_2d_sequence

    node_graph_indices = Input(shape=(1, ),
                               name='node_graph_indices',
                               dtype='int32')
    atom_types = Input(shape=(1, ), name='atom', dtype='int32')
    bond_types = Input(shape=(1, ), name='bond', dtype='int32')
    connectivity = Input(shape=(2, ), name='connectivity', dtype='int32')

    # One Squeeze layer instance is applied to each of the three index
    # inputs that were declared with a trailing axis of length 1.
    squeeze = Squeeze()

    snode_graph_indices = squeeze(node_graph_indices)
    satom_types = squeeze(atom_types)
    sbond_types = squeeze(bond_types)

    atom_features = 5

    atom_state = Embedding(preprocessor.atom_classes,
                           atom_features,
                           name='atom_embedding')(satom_types)

    bond_matrix = Embedding2D(preprocessor.bond_classes,
                              atom_features,
                              name='bond_embedding')(sbond_types)

    atom_rnn_layer = GRUStep(atom_features)
    message_layer = MessageLayer(reducer='sum', dropout=0.1)

    # Perform the message passing
    for _ in range(2):

        # Get the message updates to each atom
        message = message_layer([atom_state, bond_matrix, connectivity])

        # Update memory and atom states
        atom_state = atom_rnn_layer([message, atom_state])

    atom_fingerprint = Dense(64, activation='sigmoid')(atom_state)
    mol_fingerprint = ReduceAtomToMol(reducer='sum')(
        [atom_fingerprint, snode_graph_indices])

    out = Dense(1)(mol_fingerprint)
    model = GraphModel(
        [node_graph_indices, atom_types, bond_types, connectivity], [out])

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        model.compile(optimizer=keras.optimizers.Adam(lr=1E-4), loss='mse')
        model.fit_generator(sequence, epochs=1)

    loss = model.evaluate_generator(sequence)

    # Save inside pytest's tmpdir rather than via tempfile.mkstemp: mkstemp
    # returns an open file descriptor that was never closed and leaves the
    # file behind, whereas tmpdir is managed by pytest.
    fname = str(tmpdir.join('model.h5'))
    model.save(fname)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        model = load_model(fname, custom_objects=custom_layers)
        loss2 = model.evaluate_generator(sequence)

    assert_allclose(loss, loss2)