Example #1
0
def message_block(atom_state, bond_state, connectivity):
    """Build one message-passing step from atom/bond tensors.

    Stages:
      1. Edge update: concatenate the two gathered endpoint atoms with the
         bond state and pass them through a softplus -> linear Dense stack.
      2. Message function: two more softplus Dense layers on the bond state;
         the result is multiplied elementwise with a linear projection of
         the source atom, then reduced onto atoms with
         ``ReduceBondToAtom(reducer='sum')``.
      3. State transition: softplus -> linear Dense stack on the reduced
         messages, added to the incoming ``atom_state``.

    The layer width comes from a module-level ``atom_features`` name that is
    defined outside this function.

    Returns:
        ``(atom_state, bond_state)``. The bond state returned is the output
        of the message-function Dense layers (stage 2), not only the
        edge-update network.
    """

    def stack(tensor, spec):
        # Apply Dense layers in the listed order; an activation of None is
        # Keras' default (linear).
        for width, activation in spec:
            tensor = Dense(width, activation=activation)(tensor)
        return tensor

    # Construct both gather layers first, then apply them (index 1 = source,
    # index 0 = target), matching the original creation order.
    gathers = (GatherAtomToBond(1), GatherAtomToBond(0))
    source_atom, target_atom = [
        gather([atom_state, connectivity]) for gather in gathers]

    # Edge update network
    edge_input = Concatenate()([source_atom, target_atom, bond_state])
    bond_state = stack(edge_input, [(2*atom_features, 'softplus'),
                                    (atom_features, None)])

    # Message function
    bond_state = stack(bond_state, [(atom_features, 'softplus'),
                                    (atom_features, 'softplus')])
    projected_source = Dense(atom_features)(source_atom)
    per_bond = Multiply()([projected_source, bond_state])
    aggregated = ReduceBondToAtom(reducer='sum')([per_bond, connectivity])

    # State transition function: residual update of the atom state.
    update = stack(aggregated, [(atom_features, 'softplus'),
                                (atom_features, None)])
    return Add()([atom_state, update]), bond_state
def message_block(original_atom_state, original_bond_state, connectivity, i):
    """Build one residual, batch-normalised message-passing step.

    Args:
        original_atom_state: atom tensor. It is batch-normalised before being
            gathered, and the un-normalised tensor is added to the final
            atom update.
        original_bond_state: bond tensor. It is batch-normalised before the
            edge-update network, and the un-normalised tensor is added to
            that network's output.
        connectivity: tensor passed to ``GatherAtomToBond`` and
            ``ReduceBondToAtom`` to move between atoms and bonds.
        i: block index, used only to name the concatenation layer
            (``'concat_<i>'``).

    The layer width comes from a module-level ``atom_features`` name that is
    defined outside this function.

    Returns:
        ``(atom_state, bond_state)``, each the sum of the original input and
        this block's update.
    """
    normed_atoms, normed_bonds = [
        BatchNormalization()(state)
        for state in (original_atom_state, original_bond_state)]

    # Construct both gather layers before applying either (index 1 = source,
    # index 0 = target).
    gather_source, gather_target = GatherAtomToBond(1), GatherAtomToBond(0)
    source_atom = gather_source([normed_atoms, connectivity])
    target_atom = gather_target([normed_atoms, connectivity])

    # Edge update network, added residually to the un-normalised bonds.
    edge_hidden = Concatenate(name='concat_{}'.format(i))(
        [source_atom, target_atom, normed_bonds])
    edge_hidden = Dense(2*atom_features, activation='relu')(edge_hidden)
    edge_delta = Dense(atom_features)(edge_hidden)
    bond_state = Add()([original_bond_state, edge_delta])

    # Message function: projected source atom times the updated bond state,
    # reduced from bonds onto atoms.
    projected_source = Dense(atom_features)(source_atom)
    weighted = Multiply()([projected_source, bond_state])
    incoming = ReduceBondToAtom(reducer='sum')([weighted, connectivity])

    # State transition function, added residually to the un-normalised atoms.
    hidden = Dense(atom_features, activation='relu')(incoming)
    atom_delta = Dense(atom_features)(hidden)
    return Add()([original_atom_state, atom_delta]), bond_state
Example #3
0
def test_GatherAtomToBond():
    """Check that GatherAtomToBond(index=1) selects, for each connectivity
    row, the atom whose index is in column 1 of that row."""
    atom = layers.Input(name='atom', shape=(5, ), dtype='float32')
    connectivity = layers.Input(name='connectivity',
                                shape=(2, ),
                                dtype='int32')

    gather_layer = GatherAtomToBond(index=1)
    o = gather_layer([atom, connectivity])
    # One atom-feature row (width 5) per connectivity row.
    assert o._keras_shape == (None, 5)

    # Seed the fixture for reproducibility, and cast both arrays to the
    # dtypes declared on the Inputs above. The expected values then carry
    # the same float32 precision as the model input, so the comparison does
    # not depend on float64 -> float32 rounding staying inside rtol.
    rng = np.random.RandomState(0)
    x1 = rng.rand(2, 5).astype('float32')
    x3 = np.array([[0, 1], [1, 0]], dtype='int32')

    model = GraphModel([atom, connectivity], o)
    out = model.predict_on_batch({'atom': x1, 'connectivity': x3})

    # Row 0 of x3 is [0, 1] -> column 1 is atom 1;
    # row 1 is [1, 0] -> column 1 is atom 0.
    assert_allclose(out[0], x1[1])
    assert_allclose(out[1], x1[0])