Example #1
0
def test_Embedding2D():
    """Embedding2D(3, 5) should map each integer bond class to a 5x5 matrix.

    Bonds sharing a class must receive identical matrices, and two different
    classes (1 and 0 below) must not.
    """
    bond_input = layers.Input(name='bond', shape=(1, ), dtype='int32')
    squeezed = Squeeze()(bond_input)

    embedded = Embedding2D(3, 5)(squeezed)
    # One (5, 5) matrix per bond in the batch.
    assert embedded._keras_shape == (None, 5, 5)

    graph_model = GraphModel([bond_input], embedded)

    bond_classes = np.array([1, 1, 2, 2, 0])
    matrices = graph_model.predict_on_batch([bond_classes])

    # Identical classes -> identical matrices ...
    for first, second in ((0, 1), (2, 3)):
        assert_allclose(matrices[first], matrices[second])

    # ... while class 1 (index 0) and class 0 (last index) differ somewhere.
    assert (matrices[0] != matrices[-1]).any()
Example #2
0
def test_set2set():
    """Set2Set should pool 5-feature atom vectors into a width-10 vector per graph.

    Atoms are assigned to graphs via ``node_graph_indices``; with two distinct
    graph ids the output must contain exactly two rows.
    """
    atom_input = layers.Input(name='atom', shape=(5, ), dtype='float32')
    graph_index_input = layers.Input(
        name='node_graph_indices', shape=(1, ), dtype='int32')

    flat_indices = Squeeze()(graph_index_input)

    pooled = Set2Set()([atom_input, flat_indices])
    assert pooled._keras_shape == (None, 10)

    graph_model = GraphModel([atom_input, graph_index_input], pooled)

    atom_features = np.random.rand(5, 5)
    # Atoms 0-2 belong to graph 0, atoms 3-4 to graph 1.
    membership = np.array([0, 0, 0, 1, 1])

    result = graph_model.predict_on_batch([atom_features, membership])

    assert result.shape == (2, 10)
Example #3
0
def test_ReduceAtomToMol():
    """ReduceAtomToMol (default settings) should sum atom rows per molecule.

    Rows 0-2 belong to molecule 0 and rows 3-4 to molecule 1; each output row
    must equal the column-wise sum of that molecule's atom rows.
    """
    atom_input = layers.Input(name='atom', shape=(5, ), dtype='float32')
    graph_index_input = layers.Input(
        name='node_graph_indices', shape=(1, ), dtype='int32')

    flat_indices = Squeeze()(graph_index_input)

    reduced = ReduceAtomToMol()([atom_input, flat_indices])
    assert reduced._keras_shape == (None, 5)

    graph_model = GraphModel([atom_input, graph_index_input], reduced)

    atom_features = np.random.rand(5, 5)
    membership = np.array([0, 0, 0, 1, 1])

    result = graph_model.predict_on_batch([atom_features, membership])

    # Compare each molecule's output row with a NumPy reference sum.
    for mol_index, atom_rows in enumerate((slice(0, 3), slice(3, 5))):
        assert_allclose(atom_features[atom_rows].sum(0), result[mol_index])
Example #4
0
def test_GatherMolToAtomOrBond():
    """GatherMolToAtomOrBond should broadcast each molecule's state to its atoms.

    The expected output is plain NumPy fancy indexing of the per-molecule
    states by the per-atom molecule index.
    """
    state_input = layers.Input(
        name='global_state', shape=(5, ), dtype='float32')
    graph_index_input = layers.Input(
        name='node_graph_indices', shape=(1, ), dtype='int32')

    flat_indices = Squeeze()(graph_index_input)

    gathered = GatherMolToAtomOrBond()([state_input, flat_indices])
    assert gathered._keras_shape == (None, 5)

    graph_model = GraphModel([state_input, graph_index_input], gathered)

    # Two molecules' states, five atoms pointing at them.
    mol_states = np.random.rand(2, 5)
    membership = np.array([0, 0, 0, 1, 1])

    result = graph_model.predict_on_batch([mol_states, membership])
    expected = mol_states[membership]
    assert_allclose(result, expected)
Example #5
0
# atom_contributions = pd.Series(model.coef_.flatten(), index=X.columns)
# atom_contributions = atom_contributions.reindex(np.arange(preprocessor.atom_classes)).fillna(0)
# NOTE(review): the two lines above are disabled code; either delete them or
# re-enable them if the per-atom-class contributions are still needed.

# Construct input sequences: the training and validation inputs/targets are
# wrapped in GraphSequence objects with a batch size of 32.
# NOTE(review): final_batch=False presumably drops a trailing partial batch —
# confirm against GraphSequence.
batch_size = 32
train_sequence = GraphSequence(train_inputs, y_train_scaled, batch_size, final_batch=False)
valid_sequence = GraphSequence(valid_inputs, y_valid_scaled, batch_size, final_batch=False)

# Raw graph inputs. Names must match the keys produced by the preprocessor
# (presumably; TODO confirm). Integer indices/types use int32; the 150-wide
# 'distance_rbf' input is float32 (looks like an RBF expansion of
# distances — verify).
node_graph_indices = Input(shape=(1,), name='node_graph_indices', dtype='int32')
atom_types = Input(shape=(1,), name='atom', dtype='int32')
distance_rbf = Input(shape=(150,), name='distance_rbf', dtype='float32')
connectivity = Input(shape=(2,), name='connectivity', dtype='int32')

# A single Squeeze layer instance is reused for every shape-(1,) input,
# presumably to drop the trailing length-1 axis before indexing/embedding.
squeeze = Squeeze()

snode_graph_indices = squeeze(node_graph_indices)
satom_types = squeeze(atom_types)

# Width of the per-atom state vectors produced by the atom embedding below.
atom_features = 64

# Initialize the atom states: one learned atom_features-wide vector per
# atom class.
atom_state = Embedding(
    preprocessor.atom_classes,
    atom_features, name='atom_embedding')(satom_types)

# One learned scalar per atom class (the layer name suggests a per-atom
# energy contribution; how it is combined happens outside this fragment).
atomwise_energy = Embedding(
    preprocessor.atom_classes, 1, name='atomwise_energy',
)(satom_types)
Example #6
0
def test_save_and_load_model(get_2d_sequence, tmpdir):
    """Train a small message-passing model, save it, reload it, and compare.

    The model embeds atom and bond types, runs two rounds of message passing
    with a shared MessageLayer/GRUStep pair, sums atom fingerprints per
    molecule, and regresses one output. After one training epoch the model is
    saved to HDF5 and reloaded with ``custom_layers``. The reloaded model must
    reproduce the evaluation loss of the original.

    Parameters
    ----------
    get_2d_sequence : fixture
        Provides ``(preprocessor, sequence)``. ``sequence`` is used for both
        training and evaluation.
    tmpdir : py.path.local
        pytest's per-test temporary directory. The saved model goes here
        instead of a ``tempfile.mkstemp`` file, whose returned OS file
        descriptor was previously leaked and whose file was never removed.
    """
    preprocessor, sequence = get_2d_sequence

    node_graph_indices = Input(shape=(1, ),
                               name='node_graph_indices',
                               dtype='int32')
    atom_types = Input(shape=(1, ), name='atom', dtype='int32')
    bond_types = Input(shape=(1, ), name='bond', dtype='int32')
    connectivity = Input(shape=(2, ), name='connectivity', dtype='int32')

    # One Squeeze instance is shared by all shape-(1,) inputs.
    squeeze = Squeeze()

    snode_graph_indices = squeeze(node_graph_indices)
    satom_types = squeeze(atom_types)
    sbond_types = squeeze(bond_types)

    atom_features = 5

    atom_state = Embedding(preprocessor.atom_classes,
                           atom_features,
                           name='atom_embedding')(satom_types)

    bond_matrix = Embedding2D(preprocessor.bond_classes,
                              atom_features,
                              name='bond_embedding')(sbond_types)

    # The same layers are reused in every round so their weights are shared.
    atom_rnn_layer = GRUStep(atom_features)
    message_layer = MessageLayer(reducer='sum', dropout=0.1)

    # Perform two rounds of message passing.
    for _ in range(2):

        # Get the message updates to each atom
        message = message_layer([atom_state, bond_matrix, connectivity])

        # Update memory and atom states
        atom_state = atom_rnn_layer([message, atom_state])

    atom_fingerprint = Dense(64, activation='sigmoid')(atom_state)
    mol_fingerprint = ReduceAtomToMol(reducer='sum')(
        [atom_fingerprint, snode_graph_indices])

    out = Dense(1)(mol_fingerprint)
    model = GraphModel(
        [node_graph_indices, atom_types, bond_types, connectivity], [out])

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        model.compile(optimizer=keras.optimizers.Adam(lr=1E-4), loss='mse')
        # Train briefly so the saved weights differ from the initialization.
        model.fit_generator(sequence, epochs=1)

    loss = model.evaluate_generator(sequence)

    # Save inside pytest's tmpdir. This avoids leaking the descriptor from
    # tempfile.mkstemp and leaving stray .h5 files behind.
    fname = str(tmpdir.join('model.h5'))
    model.save(fname)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # custom_layers maps the project's layer names to classes so that
        # load_model can deserialize them.
        model = load_model(fname, custom_objects=custom_layers)
        loss2 = model.evaluate_generator(sequence)

    assert_allclose(loss, loss2)