def message_block(atom_state, bond_state, connectivity):
    """Build one edge-conditioned message-passing block from Keras tensors.

    The block works in four stages:

    1. Refresh each bond's state from its two endpoint atoms.
    2. Turn the refreshed state into a per-bond filter.
    3. Gate a projection of the source-atom state with that filter.
    4. Aggregate the gated messages onto atoms with
       ``ReduceBondToAtom(reducer='sum')``, pass them through a small dense
       stack, and add them residually to the incoming atom state.

    Args:
        atom_state: Per-atom feature tensor.
        bond_state: Per-bond feature tensor.
        connectivity: Bond/atom index tensor consumed by ``GatherAtomToBond``
            and ``ReduceBondToAtom``.

    Returns:
        ``(atom_state, bond_state)``. The returned bond state is the output
        of the two softplus filter layers, not of the edge-update network
        alone.

    Note:
        Layer widths come from a module-level ``atom_features`` defined
        outside this function. The final residual ``Add`` requires
        ``atom_state`` to already have that width.
    """
    width = atom_features

    # Both gather layers are created before either is called, as in the
    # original construction order. NOTE(review): the integer argument
    # presumably selects the connectivity column; confirm against
    # GatherAtomToBond.
    gather_src, gather_dst = GatherAtomToBond(1), GatherAtomToBond(0)
    src = gather_src([atom_state, connectivity])
    dst = gather_dst([atom_state, connectivity])

    # Edge update network: concat(src, dst, bond) -> 2*width (softplus) -> width.
    edge = Concatenate()([src, dst, bond_state])
    edge = Dense(2 * width, activation='softplus')(edge)
    edge = Dense(width)(edge)

    # Message function: two softplus layers turn the bond state into a filter
    # that gates a linear projection of the source-atom state.
    filt = edge
    for _ in range(2):
        filt = Dense(width, activation='softplus')(filt)
    projected_src = Dense(width)(src)
    gated = Multiply()([projected_src, filt])
    gathered = ReduceBondToAtom(reducer='sum')([gated, connectivity])

    # State transition function, then a residual connection onto the atoms.
    update = Dense(width, activation='softplus')(gathered)
    update = Dense(width)(update)
    new_atoms = Add()([atom_state, update])

    return new_atoms, filt
def message_block(original_atom_state, original_bond_state, connectivity, i):
    """Build one residual, batch-normalised message-passing block.

    Both incoming states are batch-normalised before use. The block then:

    - Computes an edge update from the normalised endpoint atoms and bond
      state, and adds it to the raw bond state.
    - Uses the updated bond state directly to gate a projection of the
      source atoms.
    - Aggregates those messages onto atoms with
      ``ReduceBondToAtom(reducer='sum')``, transforms them, and adds them to
      the raw atom state.

    Args:
        original_atom_state: Per-atom feature tensor.
        original_bond_state: Per-bond feature tensor.
        connectivity: Bond/atom index tensor consumed by ``GatherAtomToBond``
            and ``ReduceBondToAtom``.
        i: Used only to name the concatenation layer ``'concat_{i}'``.
            Presumably it distinguishes stacked blocks so layer names stay
            unique; confirm at the call site.

    Returns:
        ``(atom_state, bond_state)``, both after their residual additions.

    Note:
        Layer widths come from a module-level ``atom_features`` defined
        outside this function.
    """
    width = atom_features

    # Normalise the inputs. The un-normalised originals are kept for the
    # residual connections below.
    normed_atoms = BatchNormalization()(original_atom_state)
    normed_bonds = BatchNormalization()(original_bond_state)

    gather_src, gather_dst = GatherAtomToBond(1), GatherAtomToBond(0)
    src = gather_src([normed_atoms, connectivity])
    dst = gather_dst([normed_atoms, connectivity])

    # Edge update network, residual onto the raw bond state.
    concat_name = 'concat_{}'.format(i)
    bond_delta = Concatenate(name=concat_name)([src, dst, normed_bonds])
    bond_delta = Dense(2 * width, activation='relu')(bond_delta)
    bond_delta = Dense(width)(bond_delta)
    updated_bonds = Add()([original_bond_state, bond_delta])

    # Message function: projected source atoms gated by the updated bonds.
    projected_src = Dense(width)(src)
    gated = Multiply()([projected_src, updated_bonds])
    gathered = ReduceBondToAtom(reducer='sum')([gated, connectivity])

    # State transition function, residual onto the raw atom state.
    atom_delta = Dense(width, activation='relu')(gathered)
    atom_delta = Dense(width)(atom_delta)
    updated_atoms = Add()([original_atom_state, atom_delta])

    return updated_atoms, updated_bonds
def test_ReduceBondToAtom():
    """Check that ``ReduceBondToAtom(reducer='max')`` max-pools bonds per atom.

    With the connectivity below, the test expects bonds 0-2 to reduce onto
    atom 0 and bonds 3-4 onto atom 1. Each output row must therefore equal
    the column-wise max of the corresponding bond rows.
    """
    bond = layers.Input(name='bond', shape=(5, ), dtype='float32')
    connectivity = layers.Input(name='connectivity', shape=(2, ), dtype='int32')

    reduce_layer = ReduceBondToAtom(reducer='max')
    o = reduce_layer([bond, connectivity])
    assert o._keras_shape == (None, 5)

    model = GraphModel([bond, connectivity], o)

    # Seeded RNG so any failure is reproducible. Cast to float32 to match the
    # `bond` input: the reference max is then computed on exactly the values
    # the model receives. A float64 reference only passes assert_allclose's
    # default rtol=1e-7 because float32 rounding error happens to be smaller.
    rng = np.random.RandomState(0)
    x1 = rng.rand(5, 5).astype('float32')
    # Rows are (col0, col1) per bond: col0 = [0, 0, 0, 1, 1], col1 all 1s.
    # The int32 dtype matches the `connectivity` input.
    x2 = np.array([[0, 0, 0, 1, 1], [1, 1, 1, 1, 1]], dtype='int32').T

    out = model.predict_on_batch([x1, x2])
    assert_allclose(x1[:3].max(0), out[0])
    assert_allclose(x1[3:].max(0), out[1])