Exemplo n.º 1
0
    atom_rnn_layer = GRUStep(atom_features)
    message_layer = MessageLayer(reducer='sum')

    message_steps = 3
    # Perform the message passing
    for _ in range(message_steps):

        # Get the message updates to each atom
        message = message_layer([atom_state, bond_matrix, connectivity])

        # Update memory and atom states
        atom_state = atom_rnn_layer([message, atom_state])

    atom_fingerprint = Dense(64, activation='relu')(atom_state)
    mol_fingerprint = GraphOutput(reducer='sum')(
        [snode_graph_indices, atom_fingerprint])
    mol_fingerprint = BatchNormalization()(mol_fingerprint)

    out = Dense(1)(mol_fingerprint)
    model = GraphModel(
        [node_graph_indices, atom_types, bond_types, connectivity], [out])
    model.compile(optimizer=keras.optimizers.Adam(lr=0.001), loss='mse')
    model.summary()

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        hist = model.fit_generator(train_generator,
                                   validation_data=test_generator,
                                   epochs=50,
                                   verbose=2)
Exemplo n.º 2
0
    atom_state, bond_state = message_block(atom_state, bond_state, connectivity)

# Readout: project each atom's state down to a single scalar.
atom_state = Dense(atom_features//2, activation='softplus')(atom_state)
atom_state = Dense(1)(atom_state)
# Add the network's per-atom output to ``atomwise_energy``
# (NOTE(review): presumably a per-atom baseline computed earlier -- confirm).
atom_state = Add()([atom_state, atomwise_energy])

# Pool the per-atom values into one value per molecule by taking the mean
# over the atoms that share a graph index.
output = ReduceAtomToMol(reducer='mean')([atom_state, snode_graph_indices])

model = GraphModel([
    node_graph_indices, atom_types, distance_rbf, connectivity], [output])

lr = 1E-4
epochs = 500

model.compile(optimizer=keras.optimizers.Adam(lr=lr, decay=1E-5), loss='mae')
model.summary()

# Create the output directory. Try-then-check (EAFP) avoids the race in
# "if not exists: makedirs", where another process creating the directory
# in between would make makedirs raise. Any other OSError (permissions,
# a regular file in the way, ...) is still raised.
try:
    os.makedirs(model_name)
except OSError:
    if not os.path.isdir(model_name):
        raise

# Persist the preprocessor next to the model weights (presumably so the same
# preprocessing can be reapplied when the model is used later -- confirm).
with open(os.path.join(model_name, 'schnet_preprocessor.p'), 'wb') as f:
    pickle.dump(preprocessor, f)

filepath = os.path.join(model_name, "best_model.hdf5")
# Only the best checkpoint is kept (save_best_only=True); period=10 limits
# how often the callback considers saving.
checkpoint = ModelCheckpoint(filepath, save_best_only=True, period=10, verbose=1)
csv_logger = CSVLogger(os.path.join(model_name, 'log.csv'))

hist = model.fit_generator(train_sequence, validation_data=valid_sequence,
                           epochs=epochs, verbose=1,
                           callbacks=[checkpoint, csv_logger])
Exemplo n.º 3
0
def test_save_and_load_model(get_2d_sequence, tmpdir):
    """Check that a trained GraphModel survives a save/load round trip.

    Builds a small message-passing network (atom/bond embeddings, two rounds
    of MessageLayer + GRUStep, sum-pooled molecule fingerprint, scalar
    output) and trains it for one epoch. It then saves the model to an HDF5
    file inside pytest's ``tmpdir``, reloads it with ``custom_layers`` as
    ``custom_objects``, and asserts that the reloaded model reproduces the
    original evaluation loss.

    Parameters
    ----------
    get_2d_sequence
        Fixture unpacked as ``(preprocessor, sequence)``. The preprocessor
        supplies ``atom_classes`` / ``bond_classes`` for the embeddings. The
        sequence is used both to train and to evaluate the model.
    tmpdir
        pytest's per-test temporary directory. The saved model is written
        here, so pytest owns its cleanup.
    """
    preprocessor, sequence = get_2d_sequence

    node_graph_indices = Input(shape=(1, ),
                               name='node_graph_indices',
                               dtype='int32')
    atom_types = Input(shape=(1, ), name='atom', dtype='int32')
    bond_types = Input(shape=(1, ), name='bond', dtype='int32')
    connectivity = Input(shape=(2, ), name='connectivity', dtype='int32')

    # Drop the trailing length-1 axis of the index/type inputs before use.
    squeeze = Squeeze()

    snode_graph_indices = squeeze(node_graph_indices)
    satom_types = squeeze(atom_types)
    sbond_types = squeeze(bond_types)

    atom_features = 5

    atom_state = Embedding(preprocessor.atom_classes,
                           atom_features,
                           name='atom_embedding')(satom_types)

    bond_matrix = Embedding2D(preprocessor.bond_classes,
                              atom_features,
                              name='bond_embedding')(sbond_types)

    # The same layer instances are reused on every step, so the weights are
    # shared across message-passing iterations.
    atom_rnn_layer = GRUStep(atom_features)
    message_layer = MessageLayer(reducer='sum', dropout=0.1)

    # Perform the message passing
    for _ in range(2):

        # Get the message updates to each atom
        message = message_layer([atom_state, bond_matrix, connectivity])

        # Update memory and atom states
        atom_state = atom_rnn_layer([message, atom_state])

    atom_fingerprint = Dense(64, activation='sigmoid')(atom_state)
    mol_fingerprint = ReduceAtomToMol(reducer='sum')(
        [atom_fingerprint, snode_graph_indices])

    out = Dense(1)(mol_fingerprint)
    model = GraphModel(
        [node_graph_indices, atom_types, bond_types, connectivity], [out])

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        model.compile(optimizer=keras.optimizers.Adam(lr=1E-4), loss='mse')
        model.fit_generator(sequence, epochs=1)

    loss = model.evaluate_generator(sequence)

    # Save under pytest's tmpdir instead of tempfile.mkstemp(). mkstemp
    # returns an open OS-level file descriptor, which was never closed, and
    # its file was never deleted after the test.
    fname = str(tmpdir.join('model.h5'))
    model.save(fname)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # The project's custom layers must be supplied so deserialization
        # can rebuild them.
        model = load_model(fname, custom_objects=custom_layers)
        loss2 = model.evaluate_generator(sequence)

    assert_allclose(loss, loss2)