def message_block(original_atom_state, original_bond_state, connectivity, i,
                  *, units=None):
    """Run one residual edge-then-atom message-passing update.

    Bond states are updated first from the concatenated (source atom,
    target atom, bond) features. The updated bond states then gate
    per-bond messages, which are summed onto atoms and added back to the
    original atom states.

    Args:
        original_atom_state: Atom feature tensor; residual base for the
            returned atom state.
        original_bond_state: Bond feature tensor; residual base for the
            returned bond state.
        connectivity: Integer index tensor. Column 1 is gathered as the
            "source" atom and column 0 as the "target" atom of each bond.
            (Assumed shape ``[batch, num_bonds, 2]`` — TODO confirm.)
        i: Block index. It is used only to give the concatenation layer the
            name ``concat_{i}``.
        units: Width of the Dense layers. If omitted, it falls back to the
            module-level ``embed_dimension`` (the original behaviour). It
            presumably must match the last dimension of both state tensors,
            because the outputs are combined with them through ``Add``.

    Returns:
        Tuple ``(atom_state, bond_state)`` of updated states.
    """
    if units is None:
        # Legacy behaviour: width taken from a module-level global.
        units = embed_dimension

    atom_state = original_atom_state
    bond_state = original_bond_state

    source_atom = nfp.Gather()(
        [atom_state, nfp.Slice(np.s_[:, :, 1])(connectivity)])
    target_atom = nfp.Gather()(
        [atom_state, nfp.Slice(np.s_[:, :, 0])(connectivity)])

    # Edge update network (residual on the original bond state).
    new_bond_state = layers.Concatenate(name='concat_{}'.format(i))(
        [source_atom, target_atom, bond_state])
    new_bond_state = layers.Dense(2 * units, activation='relu')(new_bond_state)
    new_bond_state = layers.Dense(units)(new_bond_state)
    bond_state = layers.Add()([original_bond_state, new_bond_state])

    # Message function: project the source atom and gate it with the
    # *updated* bond state, then sum messages grouped by connectivity
    # column 0. atom_state is passed to Reduce, presumably to fix the output
    # (atom) dimension — confirm against nfp.Reduce.
    source_atom = layers.Dense(units)(source_atom)
    messages = layers.Multiply()([source_atom, bond_state])
    messages = nfp.Reduce(reduction='sum')(
        [messages, nfp.Slice(np.s_[:, :, 0])(connectivity), atom_state])

    # State transition function (residual on the original atom state).
    messages = layers.Dense(units, activation='relu')(messages)
    messages = layers.Dense(units)(messages)
    atom_state = layers.Add()([original_atom_state, messages])

    return atom_state, bond_state
def message_block(original_atom_state, original_bond_state, connectivity,
                  *, units=None):
    """Perform one graph-aware, pre-normalized residual update.

    Both input states are layer-normalized first. Atom features gathered
    along the bonds come from the normalized atom state, and the edge update
    reads the normalized bond state. The two residual ``Add`` layers use the
    un-normalized original inputs.

    Args:
        original_atom_state: Atom feature tensor; residual base for the
            returned atom state.
        original_bond_state: Bond feature tensor; residual base for the
            returned bond state.
        connectivity: Integer index tensor. Column 1 is gathered as the
            "source" atom and column 0 as the "target" atom of each bond.
            (Assumed shape ``[batch, num_bonds, 2]`` — TODO confirm.)
        units: Width of the Dense layers. If omitted, it falls back to the
            module-level ``atom_features`` (the original behaviour). It
            presumably must match the last dimension of both state tensors,
            because the outputs are combined with them through ``Add``.

    Returns:
        Tuple ``(atom_state, bond_state)`` of updated states.
    """
    if units is None:
        # Legacy behaviour: width taken from a module-level global.
        units = atom_features

    atom_state = layers.LayerNormalization()(original_atom_state)
    bond_state = layers.LayerNormalization()(original_bond_state)

    source_atom = nfp.Gather()(
        [atom_state, nfp.Slice(np.s_[:, :, 1])(connectivity)])
    target_atom = nfp.Gather()(
        [atom_state, nfp.Slice(np.s_[:, :, 0])(connectivity)])

    # Edge update network (residual on the un-normalized bond state).
    new_bond_state = layers.Concatenate()(
        [source_atom, target_atom, bond_state])
    new_bond_state = layers.Dense(2 * units, activation='relu')(new_bond_state)
    new_bond_state = layers.Dense(units)(new_bond_state)
    bond_state = layers.Add()([original_bond_state, new_bond_state])

    # Message function: project the source atom, gate it with the updated
    # bond state, and sum messages grouped by connectivity column 0.
    source_atom = layers.Dense(units)(source_atom)
    messages = layers.Multiply()([source_atom, bond_state])
    messages = nfp.Reduce(reduction='sum')(
        [messages, nfp.Slice(np.s_[:, :, 0])(connectivity), atom_state])

    # State transition function (residual on the un-normalized atom state).
    messages = layers.Dense(units, activation='relu')(messages)
    messages = layers.Dense(units)(messages)
    atom_state = layers.Add()([original_atom_state, messages])

    return atom_state, bond_state
def build(self, input_shape):
    """Instantiate the sub-layers this layer uses in ``call``.

    ``input_shape`` describes the inputs ``[atom_state, bond_state,
    connectivity]``, where the bond state has shape
    ``[batch, num_bonds, bond_features]``.
    """
    super().build(input_shape)
    # Creation order is kept unchanged (presumably relevant for Keras
    # auto-generated sub-layer names).
    self.gather = nfp.Gather()
    self.slice1, self.slice0 = (
        nfp.Slice(np.s_[:, :, 1]),
        nfp.Slice(np.s_[:, :, 0]),
    )
    self.concat = nfp.ConcatDense()
def build(self, input_shape):
    """Instantiate sub-layers; the MLP width follows ``input_shape[1][-1]``.

    ``input_shape[1]`` is presumably the bond-state shape — confirm against
    the enclosing layer's ``call``.
    """
    super().build(input_shape)
    bond_dim = input_shape[1][-1]

    # Index helpers.
    self.gather = nfp.Gather()
    self.slice0 = nfp.Slice(np.s_[:, :, 0])
    self.slice1 = nfp.Slice(np.s_[:, :, 1])

    # Combination / aggregation.
    self.concat = nfp.ConcatDense()
    self.reduce = nfp.Reduce(reduction='sum')

    # Two-layer MLP: widen to twice the feature size, then project back.
    self.dense1 = layers.Dense(2 * bond_dim, activation='relu')
    self.dense2 = layers.Dense(bond_dim)
def test_slice():
    """nfp.Slice should extract individual columns of a connectivity tensor."""
    connectivity = layers.Input(shape=[None, 2], dtype=tf.int64,
                                name='connectivity')
    sliced_columns = [nfp.Slice(np.s_[:, :, col])(connectivity)
                      for col in (0, 1)]
    model = tf.keras.Model([connectivity], sliced_columns)

    # A single batch holding five (column 0, column 1) pairs.
    pairs = np.array([[1, 6], [2, 7], [3, 8], [4, 9], [5, 0]])
    pairs = pairs[np.newaxis, :, :]
    result = model(pairs)

    for col, column_out in enumerate(result):
        assert_allclose(column_out, pairs[:, :, col])
def test_reduce(smiles_inputs, method):
    """Compare nfp.Reduce(method) against the NumPy function of the same name.

    ``method`` selects both the nfp reduction and the reference
    ``getattr(np, method)``.
    """
    preprocessor, inputs = smiles_inputs
    reference = getattr(np, method)

    atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2], dtype=tf.int64,
                                name='connectivity')

    atom_embed = layers.Embedding(
        preprocessor.atom_classes, 16, mask_zero=True)(atom_class)
    bond_embed = layers.Embedding(
        preprocessor.bond_classes, 16, mask_zero=True)(bond_class)

    reduced = nfp.Reduce(method)(
        [bond_embed, nfp.Slice(np.s_[:, :, 0])(connectivity), atom_embed])

    model = tf.keras.Model([atom_class, bond_class, connectivity],
                           [atom_embed, bond_embed, reduced])
    _, bond_state, atom_reduced = model(
        [inputs['atom'], inputs['bond'], inputs['connectivity']])

    # Atoms 0 and 1: reduction over a contiguous run of their bonds.
    assert_allclose(atom_reduced[0, 0, :], reference(bond_state[0, :4, :], 0))
    assert_allclose(atom_reduced[0, 1, :], reference(bond_state[0, 4:8, :], 0))

    # Atoms 2..5 are each expected to equal one bond state exactly (rtol=0).
    # Indices depend on the fixture molecule's bond ordering.
    # NOTE(review): bond 8 is never checked (atom 1 ends at bond 7, atom 2
    # starts at bond 9) — confirm against the smiles_inputs fixture.
    for atom_idx, bond_idx in zip(range(2, 6), range(9, 13)):
        assert_allclose(atom_reduced[0, atom_idx, :],
                        bond_state[0, bond_idx, :], rtol=0)