예제 #1
0
def message_block(atom_state, bond_state, connectivity):
    """Build one edge-conditioned message-passing round from Keras layers.

    Three stages, all of width ``atom_features`` (a module-level name defined
    outside this function):

    1. Edge update: concatenate the two gathered endpoint-atom states with the
       bond state, then apply Dense(2*atom_features, softplus) followed by a
       linear Dense(atom_features).
    2. Message function: two further softplus Dense layers on the bond state.
       The result is multiplied element-wise by a linear projection of the
       source atom. The per-bond products are summed onto atoms with
       ``ReduceBondToAtom(reducer='sum')``.
    3. State transition: a softplus Dense then a linear Dense on the summed
       messages, added residually to the incoming ``atom_state``.

    Args:
        atom_state: Keras tensor of atom states.
        bond_state: Keras tensor of bond states.
        connectivity: bond endpoint index tensor. The gather built with index
            1 is treated as the "source" atom and index 0 as the "target"
            atom. Presumably (n_bonds, 2) integers -- TODO confirm.

    Returns:
        Tuple ``(atom_state, bond_state)``. The returned bond state is the
        output of the message-function softplus layers, not the raw
        edge-update output.
    """
    # NOTE(review): layers are constructed in a fixed order on purpose. Keras
    # typically derives default layer names (and weight-init order) from
    # construction order, so confirm before reordering anything below.
    gather_source = GatherAtomToBond(1)
    gather_target = GatherAtomToBond(0)

    src = gather_source([atom_state, connectivity])
    dst = gather_target([atom_state, connectivity])

    # -- edge update network: [source, target, bond] -> softplus -> linear
    edge = Concatenate()([src, dst, bond_state])
    for units, act in ((2 * atom_features, 'softplus'), (atom_features, None)):
        edge = Dense(units, activation=act)(edge)

    # -- message function: two more softplus layers on the bond, scaled
    # element-wise by a linear projection of the source atom, summed per atom
    for _ in range(2):
        edge = Dense(atom_features, activation='softplus')(edge)
    src_proj = Dense(atom_features)(src)
    per_bond = Multiply()([src_proj, edge])
    incoming = ReduceBondToAtom(reducer='sum')([per_bond, connectivity])

    # -- state transition: softplus -> linear, then residual add onto atoms
    update = Dense(atom_features, activation='softplus')(incoming)
    update = Dense(atom_features)(update)
    new_atom_state = Add()([atom_state, update])

    return new_atom_state, edge
예제 #2
0
def message_block(original_atom_state, original_bond_state, connectivity, i):
    """Build one batch-normalised, residual message-passing round.

    Both input states are first passed through their own BatchNormalization
    layer. The normalised atom states are gathered onto bonds.

    - Edge update: an MLP over [source, target, normalised bond] whose output
      is added back onto the *un-normalised* ``original_bond_state``.
    - Message: a linear projection of the (normalised) source atom is
      multiplied element-wise by the updated bond state and summed per atom.
    - State transition: an MLP over the summed messages, added back onto the
      *un-normalised* ``original_atom_state``.

    Hidden width comes from the module-level ``atom_features`` (defined outside
    this function). ReLU is used as the hidden activation.

    Args:
        original_atom_state: Keras tensor of atom states.
        original_bond_state: Keras tensor of bond states.
        connectivity: bond endpoint index tensor. The gather built with index
            1 is treated as "source" and index 0 as "target". Presumably
            (n_bonds, 2) integers -- TODO confirm.
        i: block index. It is only used to name the concatenation layer
            ``'concat_<i>'``.

    Returns:
        Tuple ``(atom_state, bond_state)`` of the residually updated states.
    """
    # NOTE(review): construction order of the layers below is deliberate. Keras
    # typically derives default layer names from it -- confirm before reordering.

    def two_layer(x, first_units):
        # ReLU hidden layer of `first_units`, then a linear atom_features layer.
        hidden = Dense(first_units, activation='relu')(x)
        return Dense(atom_features)(hidden)

    normed_atoms = BatchNormalization()(original_atom_state)
    normed_bonds = BatchNormalization()(original_bond_state)

    gather_source = GatherAtomToBond(1)
    gather_target = GatherAtomToBond(0)
    src = gather_source([normed_atoms, connectivity])
    dst = gather_target([normed_atoms, connectivity])

    # edge update network, residual onto the raw (un-normalised) bond state
    edge_in = Concatenate(name='concat_{}'.format(i))([src, dst, normed_bonds])
    updated_bonds = Add()([original_bond_state, two_layer(edge_in, 2 * atom_features)])

    # message function: projected source atom scaled by the updated bond
    src_proj = Dense(atom_features)(src)
    per_bond = Multiply()([src_proj, updated_bonds])
    incoming = ReduceBondToAtom(reducer='sum')([per_bond, connectivity])

    # state transition, residual onto the raw (un-normalised) atom state
    updated_atoms = Add()([original_atom_state, two_layer(incoming, atom_features)])

    return updated_atoms, updated_bonds
예제 #3
0
def test_ReduceBondToAtom():
    """ReduceBondToAtom(reducer='max') should max-pool bond rows per atom.

    The connectivity used here has column 0 = [0, 0, 0, 1, 1] and column 1 all
    ones. The test expects output row 0 to be the column-wise max of bonds 0-2,
    and output row 1 to be the column-wise max of bonds 3-4.
    """
    bond = layers.Input(name='bond', shape=(5, ), dtype='float32')
    connectivity = layers.Input(name='connectivity', shape=(2, ), dtype='int32')

    pooled = ReduceBondToAtom(reducer='max')([bond, connectivity])
    # symbolic output shape: batch dimension unknown, feature width preserved
    assert pooled._keras_shape == (None, 5)

    model = GraphModel([bond, connectivity], pooled)

    bond_features = np.random.rand(5, 5)
    # rows: three bonds [0, 1] followed by two bonds [1, 1]
    bond_atoms = np.array([[0, 1]] * 3 + [[1, 1]] * 2)

    out = model.predict_on_batch([bond_features, bond_atoms])

    for atom, rows in ((0, slice(0, 3)), (1, slice(3, None))):
        assert_allclose(bond_features[rows].max(0), out[atom])