def message_block(original_atom_state, original_bond_state, connectivity, i, *, units=None):
    """Apply one round of edge-then-node message passing with residual updates.

    Args:
        original_atom_state: Atom feature tensor; added back onto the computed
            messages as a residual.
        original_bond_state: Bond feature tensor; added back onto the edge
            network output as a residual.
        connectivity: Integer index tensor whose ``[:, :, 0]`` and ``[:, :, 1]``
            slices select the two atoms of each bond (the slicing implies a
            rank-3 tensor with a trailing pair axis -- TODO confirm with caller).
        i: Round index; only used to name the concatenation layer
            ``concat_{i}`` so repeated rounds get distinct layer names.
        units: Width of the Dense layers. Defaults to the module-level
            ``embed_dimension`` (looked up at call time), which is what this
            function previously hard-wired. Because the outputs are added and
            multiplied onto the incoming states, it should match their
            feature width.

    Returns:
        Tuple ``(atom_state, bond_state)`` with the residually updated atom
        and bond tensors.
    """
    if units is None:
        # Backward-compatible fallback to the global the original relied on.
        units = embed_dimension

    atom_state = original_atom_state
    bond_state = original_bond_state

    # Features of the atom in column 1 ("source") and column 0 ("target") of
    # each bond.
    source_atom = nfp.Gather()(
        [atom_state, nfp.Slice(np.s_[:, :, 1])(connectivity)])
    target_atom = nfp.Gather()(
        [atom_state, nfp.Slice(np.s_[:, :, 0])(connectivity)])

    # Edge update network: MLP over [source, target, bond], then residual add.
    new_bond_state = layers.Concatenate(name='concat_{}'.format(i))(
        [source_atom, target_atom, bond_state])
    new_bond_state = layers.Dense(2 * units, activation='relu')(new_bond_state)
    new_bond_state = layers.Dense(units)(new_bond_state)
    bond_state = layers.Add()([original_bond_state, new_bond_state])

    # Message function: project source atoms, gate them by the updated bond
    # state, and sum the messages onto the column-0 atom of each bond.
    source_atom = layers.Dense(units)(source_atom)
    messages = layers.Multiply()([source_atom, bond_state])
    messages = nfp.Reduce(reduction='sum')(
        [messages, nfp.Slice(np.s_[:, :, 0])(connectivity), atom_state])

    # State transition function: two-layer MLP, then residual add onto atoms.
    messages = layers.Dense(units, activation='relu')(messages)
    messages = layers.Dense(units)(messages)
    atom_state = layers.Add()([original_atom_state, messages])

    return atom_state, bond_state
def message_block(original_atom_state, original_bond_state, connectivity, *, units=None):
    """Perform one graph-aware update of the atom and bond states.

    Both states are layer-normalized before being used to build messages;
    the un-normalized originals form the residual branches.

    NOTE(review): this file also defines a ``message_block`` that takes an
    extra ``i`` argument; whichever definition executes last is the one bound
    at module level.

    Args:
        original_atom_state: Atom feature tensor (residual branch for atoms).
        original_bond_state: Bond feature tensor (residual branch for bonds).
        connectivity: Integer index tensor whose ``[:, :, 0]`` and ``[:, :, 1]``
            slices select the two atoms of each bond (rank-3 implied by the
            slicing -- TODO confirm with caller).
        units: Width of the Dense layers. Defaults to the module-level
            ``atom_features`` (looked up at call time), which is what this
            function previously hard-wired. It should match the feature width
            of both state tensors, since outputs are combined with them.

    Returns:
        Tuple ``(atom_state, bond_state)`` with the residually updated atom
        and bond tensors.
    """
    if units is None:
        # Backward-compatible fallback to the global the original relied on.
        units = atom_features

    atom_state = layers.LayerNormalization()(original_atom_state)
    bond_state = layers.LayerNormalization()(original_bond_state)

    # Features of the atom in column 1 ("source") and column 0 ("target") of
    # each bond, taken from the normalized atom state.
    source_atom = nfp.Gather()([atom_state, nfp.Slice(np.s_[:, :, 1])(connectivity)])
    target_atom = nfp.Gather()([atom_state, nfp.Slice(np.s_[:, :, 0])(connectivity)])

    # Edge update network: MLP over [source, target, bond], then residual add
    # onto the un-normalized bond state.
    new_bond_state = layers.Concatenate()(
        [source_atom, target_atom, bond_state])
    new_bond_state = layers.Dense(
        2 * units, activation='relu')(new_bond_state)
    new_bond_state = layers.Dense(units)(new_bond_state)
    bond_state = layers.Add()([original_bond_state, new_bond_state])

    # Message function: project source atoms, gate them by the updated bond
    # state, and sum the messages onto the column-0 atom of each bond.
    # atom_state is passed as the third input, presumably as a size reference
    # for the output -- verify against nfp.Reduce.
    source_atom = layers.Dense(units)(source_atom)
    messages = layers.Multiply()([source_atom, bond_state])
    messages = nfp.Reduce(reduction='sum')(
        [messages, nfp.Slice(np.s_[:, :, 0])(connectivity), atom_state])

    # State transition function: two-layer MLP, then residual add onto the
    # un-normalized atom state.
    messages = layers.Dense(units, activation='relu')(messages)
    messages = layers.Dense(units)(messages)
    atom_state = layers.Add()([original_atom_state, messages])

    return atom_state, bond_state
def build(self, input_shape):
    """Create the helper sublayers this layer uses.

    ``input_shape`` describes ``[atom_state, bond_state, connectivity]``,
    where the bond state has shape ``[batch, num_bonds, bond_features]``.
    """
    super().build(input_shape)

    # One tuple assignment, but construction and attribute assignment still
    # happen left to right (gather, slice1, slice0, concat), matching the
    # original order so any auto-generated layer names are unchanged.
    self.gather, self.slice1, self.slice0, self.concat = (
        nfp.Gather(),
        nfp.Slice(np.s_[:, :, 1]),
        nfp.Slice(np.s_[:, :, 0]),
        nfp.ConcatDense(),
    )
def test_gather():
    """nfp.Gather should return ``data[b, indices[b, k]]`` for every row ``b``."""
    in1 = layers.Input(shape=[None], dtype='float', name='data')
    in2 = layers.Input(shape=[None], dtype=tf.int64, name='indices')
    gather = nfp.Gather()([in1, in2])
    model = tf.keras.Model([in1, in2], [gather])

    data = np.random.rand(2, 10).astype(np.float32)
    # Explicit int64 to match the declared Input dtype: NumPy's default integer
    # type is platform-dependent (int32 on Windows before NumPy 2.0).
    indices = np.array([[2, 6, 3], [5, 1, 0]], dtype=np.int64)
    out = model([data, indices])

    # Per-row gather reference, valid for any batch size:
    # expected[b, k] == data[b, indices[b, k]].
    expected = np.take_along_axis(data, indices, axis=1)
    assert_allclose(out, expected)
def build(self, input_shape):
    """Create the sublayers, sizing the dense layers from the second input.

    The width is ``input_shape[1][-1]`` (presumably the bond-state feature
    size -- confirm with the layer's call()). ``dense1`` has twice that many
    ReLU units and ``dense2`` projects back down to it.
    """
    super().build(input_shape)
    width = input_shape[1][-1]

    # Tuple assignments evaluate and assign left to right, preserving the
    # original construction order (and thus any auto-generated layer names).
    self.gather, self.slice0, self.slice1, self.concat, self.reduce = (
        nfp.Gather(),
        nfp.Slice(np.s_[:, :, 0]),
        nfp.Slice(np.s_[:, :, 1]),
        nfp.ConcatDense(),
        nfp.Reduce(reduction='sum'),
    )
    self.dense1, self.dense2 = (
        layers.Dense(2 * width, activation='relu'),
        layers.Dense(width),
    )