Esempio n. 1
0
def test_Embedding2D():
    """Embedding2D(3, 5) yields one 5x5 matrix per bond index.

    Bonds sharing an index must map to the same matrix; bonds with different
    indices (1 vs 0 here) must not.
    """
    bond_input = layers.Input(name='bond', shape=(1, ), dtype='int32')
    squeezed = Squeeze()(bond_input)

    embedded = Embedding2D(3, 5)(squeezed)
    assert embedded._keras_shape == (None, 5, 5)

    model = GraphModel([bond_input], embedded)

    bond_classes = np.array([1, 1, 2, 2, 0])
    matrices = model.predict_on_batch([bond_classes])

    # Equal bond classes -> identical matrices.
    for first, second in ((0, 1), (2, 3)):
        assert_allclose(matrices[first], matrices[second])

    # Class 1 (first row) vs class 0 (last row) should differ somewhere.
    assert not (matrices[0] == matrices[-1]).all()
Esempio n. 2
0
def test_GatherAtomToBond():
    """GatherAtomToBond(index=1) returns, for each bond, the features of the
    atom listed in column 1 of that bond's connectivity row."""
    atom_in = layers.Input(name='atom', shape=(5, ), dtype='float32')
    conn_in = layers.Input(name='connectivity', shape=(2, ), dtype='int32')

    gathered = GatherAtomToBond(index=1)([atom_in, conn_in])
    assert gathered._keras_shape == (None, 5)

    atom_feats = np.random.rand(2, 5)
    bonds = np.array([[0, 1], [1, 0]])

    model = GraphModel([atom_in, conn_in], gathered)
    result = model.predict_on_batch({
        'atom': atom_feats,
        'connectivity': bonds,
    })

    # Bond 0 is (0, 1) and bond 1 is (1, 0): column 1 selects atoms 1 then 0.
    for bond_idx, (_, target_atom) in enumerate(bonds):
        assert_allclose(result[bond_idx], atom_feats[target_atom])
Esempio n. 3
0
def test_EdgeNetwork():
    """EdgeNetwork(5, 3) maps (bond type, distance) pairs to 5x5 matrices.

    Identical (type, distance) inputs must give identical matrices, and a
    change in distance or type must change the matrix.
    """
    bond_in = layers.Input(name='bond', shape=(1, ), dtype='int32')
    dist_in = layers.Input(name='distance', shape=(1, ), dtype='float32')

    edge_matrices = EdgeNetwork(5, 3)([bond_in, dist_in])
    assert edge_matrices._keras_shape == (None, 5, 5)

    model = GraphModel([bond_in, dist_in], edge_matrices)

    bond_types = np.array([1, 1, 2, 2, 0])
    distances = np.array([1., 1., 2., 3., .5])
    result = model.predict_on_batch([bond_types, distances])

    # Same type, same distance.
    assert_allclose(result[0], result[1])
    # Same type (2), different distance (2. vs 3.): not everywhere close.
    assert not np.isclose(result[2], result[3]).all()
    # Different type and distance: not everywhere close.
    assert not np.isclose(result[0], result[-1]).all()
Esempio n. 4
0
def test_ReduceBondToAtom():
    """ReduceBondToAtom(reducer='max') takes an elementwise max over bonds.

    Bonds are grouped by the atom in column 0 of the connectivity array: here
    bonds 0-2 reduce into atom 0 and bonds 3-4 into atom 1.
    """
    bond_in = layers.Input(name='bond', shape=(5, ), dtype='float32')
    conn_in = layers.Input(name='connectivity', shape=(2, ), dtype='int32')

    reduced = ReduceBondToAtom(reducer='max')([bond_in, conn_in])
    assert reduced._keras_shape == (None, 5)

    model = GraphModel([bond_in, conn_in], reduced)

    bond_feats = np.random.rand(5, 5)
    # Column 1 is constant, so only column 0 distinguishes the two groups.
    pairs = np.array([[0, 0, 0, 1, 1], [1, 1, 1, 1, 1]]).T

    result = model.predict_on_batch([bond_feats, pairs])

    groups = (bond_feats[:3], bond_feats[3:])
    for atom_idx, bond_rows in enumerate(groups):
        assert_allclose(bond_rows.max(0), result[atom_idx])
Esempio n. 5
0
def test_set2set():
    """Set2Set pools atom features into one 10-wide row per graph index."""
    atom_in = layers.Input(name='atom', shape=(5, ), dtype='float32')
    graph_idx_in = layers.Input(name='node_graph_indices',
                                shape=(1, ),
                                dtype='int32')

    # Build Squeeze before Set2Set to keep layer construction order unchanged.
    squeezed_idx = Squeeze()(graph_idx_in)

    pooled = Set2Set()([atom_in, squeezed_idx])
    assert pooled._keras_shape == (None, 10)

    model = GraphModel([atom_in, graph_idx_in], pooled)

    atom_feats = np.random.rand(5, 5)
    membership = np.array([0, 0, 0, 1, 1])

    result = model.predict_on_batch([atom_feats, membership])

    # Two distinct graph indices -> two pooled rows.
    assert result.shape == (2, 10)
Esempio n. 6
0
def test_ReduceAtomToMol():
    """ReduceAtomToMol() with its default reducer sums atom rows that share a
    graph index."""
    atom_in = layers.Input(name='atom', shape=(5, ), dtype='float32')
    graph_idx_in = layers.Input(name='node_graph_indices',
                                shape=(1, ),
                                dtype='int32')

    squeezed_idx = Squeeze()(graph_idx_in)

    summed = ReduceAtomToMol()([atom_in, squeezed_idx])
    assert summed._keras_shape == (None, 5)

    model = GraphModel([atom_in, graph_idx_in], summed)

    atom_feats = np.random.rand(5, 5)
    membership = np.array([0, 0, 0, 1, 1])

    result = model.predict_on_batch([atom_feats, membership])

    # Atoms 0-2 belong to graph 0; atoms 3-4 belong to graph 1.
    for graph_idx, rows in enumerate((atom_feats[:3], atom_feats[3:])):
        assert_allclose(rows.sum(0), result[graph_idx])
Esempio n. 7
0
def test_GatherMolToAtomOrBond():
    """GatherMolToAtomOrBond copies each graph's global-state row to every
    entry that carries that graph index."""
    state_in = layers.Input(name='global_state', shape=(5, ), dtype='float32')
    graph_idx_in = layers.Input(name='node_graph_indices',
                                shape=(1, ),
                                dtype='int32')

    squeezed_idx = Squeeze()(graph_idx_in)

    gathered = GatherMolToAtomOrBond()([state_in, squeezed_idx])
    assert gathered._keras_shape == (None, 5)

    model = GraphModel([state_in, graph_idx_in], gathered)

    global_rows = np.random.rand(2, 5)
    membership = np.array([0, 0, 0, 1, 1])

    result = model.predict_on_batch([global_rows, membership])

    # The expected output is plain fancy indexing of the global rows.
    assert_allclose(result, global_rows[membership])
Esempio n. 8
0
def test_message():
    """MessageLayer rows equal bond matrices applied to neighbour features.

    With connectivity [[0, 1], [1, 0]], row 0 is bond 0's matrix times atom
    1's features and row 1 is bond 1's matrix times atom 0's features.
    """
    atom_in = layers.Input(name='atom', shape=(5, ), dtype='float32')
    bond_in = layers.Input(name='bond', shape=(5, 5), dtype='float32')
    conn_in = layers.Input(name='connectivity', shape=(2, ), dtype='int32')

    messages = MessageLayer()([atom_in, bond_in, conn_in])
    assert messages._keras_shape == (None, 5)

    model = GraphModel([atom_in, bond_in, conn_in], messages)

    atom_feats = np.random.rand(2, 5)
    bond_mats = np.random.rand(2, 5, 5)
    pairs = np.array([[0, 1], [1, 0]])

    result = model.predict_on_batch({
        'atom': atom_feats,
        'bond': bond_mats,
        'connectivity': pairs,
    })

    expected = np.vstack([
        bond_mats[bond_idx].dot(atom_feats[neighbour])
        for bond_idx, (_, neighbour) in enumerate(pairs)
    ])
    assert_allclose(expected, result, rtol=1E-5, atol=1E-5)
Esempio n. 9
0
    messages = Dense(atom_features, activation='softplus')(messages)
    messages = Dense(atom_features)(messages)
    atom_state = Add()([atom_state, messages])
    
    return atom_state, bond_state

# Apply three message-passing blocks (message_block is defined earlier).
for _ in range(3):
    atom_state, bond_state = message_block(atom_state, bond_state, connectivity)

# Per-atom readout down to one scalar, shifted by the `atomwise_energy` tensor.
atom_state = Dense(atom_features//2, activation='softplus')(atom_state)
atom_state = Dense(1)(atom_state)
atom_state = Add()([atom_state, atomwise_energy])

# NOTE(review): atoms are averaged per graph index here ('mean'); a sibling
# variant of this model uses 'sum' -- confirm 'mean' matches the target scaling.
output = ReduceAtomToMol(reducer='mean')([atom_state, snode_graph_indices])

model = GraphModel([
    node_graph_indices, atom_types, distance_rbf, connectivity], [output])

lr = 1E-4
epochs = 500

model.compile(optimizer=keras.optimizers.Adam(lr=lr, decay=1E-5), loss='mae')
model.summary()

# exist_ok=True replaces the exists()/makedirs() pair, which could race if the
# directory is created between the check and the creation.
os.makedirs(model_name, exist_ok=True)

# Persist the preprocessor next to the model so it can be reused at inference.
with open('{}/schnet_preprocessor.p'.format(model_name), 'wb') as f:
    pickle.dump(preprocessor, f)

# Checkpoint every 10 epochs, keeping only the best model.
filepath = model_name + "/best_model.hdf5"
checkpoint = ModelCheckpoint(filepath, save_best_only=True, period=10, verbose=1)
Esempio n. 10
0
    messages = Dense(atom_features)(messages)
    
    atom_state = Add()([original_atom_state, messages])
    
    return atom_state, bond_state

# Stack `num_messages` message-passing blocks; the index `i` is forwarded to
# message_block (presumably to distinguish per-block layers -- confirm).
for i in range(num_messages):
    atom_state, bond_state = message_block(atom_state, bond_state, connectivity, i)

# Bond-level readout: one linear output per bond, shifted by `bond_mean`.
bond_state = Dense(1)(bond_state)
bond_state = Add()([bond_state, bond_mean])

symb_inputs = [mol_type, node_graph_indices, bond_graph_indices,
               atom_types, bond_types, connectivity]

model = GraphModel(symb_inputs, [bond_state])

epochs = 500

model.compile(optimizer=keras.optimizers.Adam(lr=lr, decay=decay), loss=masked_mean_absolute_error)
# model.summary()


# exist_ok=True replaces the exists()/makedirs() pair, which could race if the
# directory is created between the check and the creation.
os.makedirs(model_name, exist_ok=True)

# Make a backup of the job submission script
shutil.copy(__file__, model_name)

# Checkpoint every 10 epochs, keeping only the best model.
filepath = model_name + "/best_model.hdf5"
checkpoint = ModelCheckpoint(filepath, save_best_only=True, period=10, verbose=0)
Esempio n. 11
0
    # Update memory and atom states
    atom_state = atom_rnn_layer([message, atom_state])

# atom_state = BatchNormalization(momentum=0.9)(atom_state)
# Readout: project each atom state to a 1024-wide sigmoid fingerprint, then sum
# the fingerprints of atoms that share an index in `snode_graph_indices`.
atom_fingerprint = Dense(1024, activation='sigmoid')(atom_state)
mol_out = ReduceAtomToMol(reducer='sum')(
    [atom_fingerprint, snode_graph_indices])

# Head: two BatchNorm -> ReLU Dense stages, then a linear layer with
# `num_output` units.
X = BatchNormalization(momentum=0.9)(mol_out)
X = Dense(512, activation='relu')(X)

X = BatchNormalization(momentum=0.9)(X)
X = Dense(256, activation='relu')(X)
X = Dense(num_output)(X)

model = GraphModel([node_graph_indices, atom_types, bond_types, connectivity],
                   [X])

epochs = 500
lr = 1E-3
# Adam decay is set to lr / epochs.
decay = lr / epochs

model.compile(optimizer=keras.optimizers.Adam(lr=lr, decay=decay),
              loss=masked_mean_squared_error)

model.summary()

# NOTE(review): assumes the `model_name` directory already exists (it is not
# created in this block) -- confirm it is created earlier in the script.
# Checkpoint every 10 epochs, keeping only the best model.
filepath = model_name + "/best_model.hdf5"
checkpoint = ModelCheckpoint(filepath,
                             save_best_only=True,
                             period=10,
                             verbose=1)
Esempio n. 12
0
# Final per-atom output: the atom state plus the `atomwise_shift` tensor.
output = Add()([atom_state, atomwise_shift])

# Path of the saved model; it is read back when restarting.
filepath = "best_model.hdf5"

lr = 5E-4
epochs = 1200

# Either reload a previously saved model (custom layer classes must be passed
# so Keras can deserialize them) or build and compile a fresh one.
if args.restart:
    model = load_model(filepath, custom_objects={'GraphModel': GraphModel, 
                                                 'Squeeze': Squeeze,
                                                 'GatherAtomToBond': GatherAtomToBond,
                                                 'ReduceBondToAtom': ReduceBondToAtom,
                                                 'ReduceAtomToPro': ReduceAtomToPro})
else:
    model = GraphModel([
        atom_index, atom_types, distance_rbf, connectivity, n_pro], [output])

    # NOTE(review): this compile is superseded by the recompile at the end of
    # this block.
    model.compile(optimizer=keras.optimizers.Adam(lr=lr), loss='mae')

# Fine-tuning setup: freeze every layer, then unfreeze selected Dense layers.
for layer in model.layers:
    layer.trainable = False

# NOTE(review): these rely on Keras auto-generated names ('dense_N'), which
# depend on layer construction order; re-check them whenever the architecture
# changes.
model.get_layer(name='dense_2').trainable = True
model.get_layer(name='dense_5').trainable = True
model.get_layer(name='dense_9').trainable = True
model.get_layer(name='dense_12').trainable = True
model.get_layer(name='dense_16').trainable = True
model.get_layer(name='dense_19').trainable = True

# Recompile after toggling `trainable`, presumably so the new flags take effect.
model.compile(optimizer=keras.optimizers.Adam(lr=lr), loss='mae')
model.summary()
Esempio n. 13
0
    atom_rnn_layer = GRUStep(atom_features)
    message_layer = MessageLayer(reducer='sum')

    message_steps = 3
    # Perform the message passing
    for _ in range(message_steps):

        # Get the message updates to each atom
        message = message_layer([atom_state, bond_matrix, connectivity])

        # Update memory and atom states
        atom_state = atom_rnn_layer([message, atom_state])

    atom_fingerprint = Dense(64, activation='relu')(atom_state)
    mol_fingerprint = GraphOutput(reducer='sum')(
        [snode_graph_indices, atom_fingerprint])
    mol_fingerprint = BatchNormalization()(mol_fingerprint)

    out = Dense(1)(mol_fingerprint)
    model = GraphModel(
        [node_graph_indices, atom_types, bond_types, connectivity], [out])
    model.compile(optimizer=keras.optimizers.Adam(lr=0.001), loss='mse')
    model.summary()

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        hist = model.fit_generator(train_generator,
                                   validation_data=test_generator,
                                   epochs=50,
                                   verbose=2)
Esempio n. 14
0
def test_save_and_load_model(get_2d_sequence, tmpdir):
    """Round-trip a trained message-passing model through HDF5.

    Builds a small GRU/message-passing model, trains it for one epoch, saves
    it, reloads it with the project's custom layers, and checks that the
    reloaded model reproduces the original evaluation loss.

    Parameters
    ----------
    get_2d_sequence : fixture
        Yields ``(preprocessor, sequence)``; ``preprocessor`` supplies
        ``atom_classes`` / ``bond_classes`` and ``sequence`` is passed to
        ``fit_generator`` / ``evaluate_generator``.
    tmpdir : py.path.local
        pytest-managed temporary directory that receives the saved model.
    """
    preprocessor, sequence = get_2d_sequence

    node_graph_indices = Input(shape=(1, ),
                               name='node_graph_indices',
                               dtype='int32')
    atom_types = Input(shape=(1, ), name='atom', dtype='int32')
    bond_types = Input(shape=(1, ), name='bond', dtype='int32')
    connectivity = Input(shape=(2, ), name='connectivity', dtype='int32')

    squeeze = Squeeze()

    snode_graph_indices = squeeze(node_graph_indices)
    satom_types = squeeze(atom_types)
    sbond_types = squeeze(bond_types)

    atom_features = 5

    atom_state = Embedding(preprocessor.atom_classes,
                           atom_features,
                           name='atom_embedding')(satom_types)

    bond_matrix = Embedding2D(preprocessor.bond_classes,
                              atom_features,
                              name='bond_embedding')(sbond_types)

    # Shared layers: the same weights are reused at every message step.
    atom_rnn_layer = GRUStep(atom_features)
    message_layer = MessageLayer(reducer='sum', dropout=0.1)

    # Perform the message passing
    for _ in range(2):

        # Get the message updates to each atom
        message = message_layer([atom_state, bond_matrix, connectivity])

        # Update memory and atom states
        atom_state = atom_rnn_layer([message, atom_state])

    atom_fingerprint = Dense(64, activation='sigmoid')(atom_state)
    mol_fingerprint = ReduceAtomToMol(reducer='sum')(
        [atom_fingerprint, snode_graph_indices])

    out = Dense(1)(mol_fingerprint)
    model = GraphModel(
        [node_graph_indices, atom_types, bond_types, connectivity], [out])

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        model.compile(optimizer=keras.optimizers.Adam(lr=1E-4), loss='mse')
        model.fit_generator(sequence, epochs=1)

    loss = model.evaluate_generator(sequence)

    # Save into pytest's tmpdir: tempfile.mkstemp returned an open OS-level
    # file descriptor that was never closed and left the file behind.
    fname = str(tmpdir.join('model.h5'))
    model.save(fname)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        reloaded = load_model(fname, custom_objects=custom_layers)
        loss2 = reloaded.evaluate_generator(sequence)

    assert_allclose(loss, loss2)
Esempio n. 15
0
    atom_state = Add()([atom_state, messages])

    return atom_state, bond_state


# Apply three message-passing blocks (message_block is defined earlier).
for _ in range(3):
    atom_state, bond_state = message_block(atom_state, bond_state,
                                           connectivity)

# Per-atom readout down to one scalar, shifted by the `atomwise_energy` tensor.
atom_state = Dense(atom_features // 2, activation='softplus')(atom_state)
atom_state = Dense(1)(atom_state)
atom_state = Add()([atom_state, atomwise_energy])

# Sum the per-atom scalars over atoms sharing a graph index.
output = ReduceAtomToMol(reducer='sum')([atom_state, snode_graph_indices])

model = GraphModel(
    [node_graph_indices, atom_types, distance_rbf, connectivity], [output])

lr = 5E-4
epochs = 500

model.compile(optimizer=keras.optimizers.Adam(lr=lr), loss='mae')
model.summary()

model_name = 'schnet_edgeupdate_fixed'

# exist_ok=True replaces the exists()/makedirs() pair, which could race if the
# directory is created between the check and the creation.
os.makedirs(model_name, exist_ok=True)

filepath = model_name + "/best_model.hdf5"
checkpoint = ModelCheckpoint(filepath,
                             save_best_only=True,
Esempio n. 16
0
# Path of the saved model: read back on restart and passed to the checkpoint.
filepath = "best_model.hdf5"

lr = 5E-4
epochs = 1200

# Either reload a previously saved model (custom layer classes must be passed
# so Keras can deserialize them) or build and compile a fresh one.
if args.restart:
    model = load_model(filepath,
                       custom_objects={
                           'GraphModel': GraphModel,
                           'Squeeze': Squeeze,
                           'GatherAtomToBond': GatherAtomToBond,
                           'ReduceBondToAtom': ReduceBondToAtom,
                           'ReduceAtomToPro': ReduceAtomToPro
                       })
else:
    model = GraphModel(
        [atom_index, atom_types, distance_rbf, connectivity, n_pro], [output])

    model.compile(optimizer=keras.optimizers.Adam(lr=lr), loss='mae')

model.summary()

# NOTE(review): this writes the current (freshly built or just reloaded)
# weights before any training in this block; confirm that is intended given
# the 'bestmodel' file name.
model.save_weights('bestmodel_weights.h5')

# Checkpoint every 10 epochs, keeping only the best model; log to log.csv.
checkpoint = ModelCheckpoint(filepath,
                             save_best_only=True,
                             period=10,
                             verbose=1)
csv_logger = CSVLogger('log.csv')


def decay_fn(epoch, learning_rate):