def test_save_and_load_message(smiles_inputs, tmpdir: 'py.path.local'):
    """Check a residual message-passing model against padding and save/load.

    Builds a three-step residual Edge/Node/Global update model, then asserts:

    1. The output for an unpadded batch matches the output for the same
       molecules padded to ``max_atoms=20, max_bonds=40``.
    2. After ``model.save(tmpdir)`` / ``load_model(tmpdir)``, the reloaded
       model reproduces the original outputs for BOTH the unpadded and the
       padded batch.
    """
    preprocessor, inputs = smiles_inputs

    def get_inputs(max_atoms=-1, max_bonds=-1):
        """Return the first padded batch built from four small SMILES.

        ``max_atoms`` / ``max_bonds`` are forwarded to
        ``preprocessor.padded_shapes``.
        """
        dataset = tf.data.Dataset.from_generator(
            lambda: (preprocessor.construct_feature_matrices(smiles, train=True)
                     for smiles in ['CC', 'CCC', 'C(C)C', 'C']),
            output_types=preprocessor.output_types,
            output_shapes=preprocessor.output_shapes,
        ).padded_batch(
            batch_size=4,
            padded_shapes=preprocessor.padded_shapes(max_atoms, max_bonds),
            padding_values=preprocessor.padding_values,
        )
        return list(dataset.take(1))[0]

    atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2], dtype=tf.int64,
                                name='connectivity')

    atom_state = layers.Embedding(preprocessor.atom_classes, 16,
                                  mask_zero=True)(atom_class)
    bond_state = layers.Embedding(preprocessor.bond_classes, 16,
                                  mask_zero=True)(bond_class)
    global_state = nfp.GlobalUpdate(8, 2)(
        [atom_state, bond_state, connectivity])

    for _ in range(3):
        new_bond_state = nfp.EdgeUpdate()(
            [atom_state, bond_state, connectivity])
        bond_state = layers.Add()([bond_state, new_bond_state])
        new_atom_state = nfp.NodeUpdate()(
            [atom_state, bond_state, connectivity])
        atom_state = layers.Add()([atom_state, new_atom_state])
        new_global_state = nfp.GlobalUpdate(8, 2)(
            [atom_state, bond_state, connectivity])
        global_state = layers.Add()([new_global_state, global_state])

    model = tf.keras.Model([atom_class, bond_class, connectivity],
                           [global_state])

    # Build each batch once and feed the same tensors to both models.
    batch = get_inputs()
    batch_pad = get_inputs(max_atoms=20, max_bonds=40)

    outputs = model(batch)
    output_pad = model(batch_pad)
    assert np.all(np.isclose(outputs, output_pad, atol=1E-4, rtol=1E-4))

    model.save(tmpdir, include_optimizer=False)
    loaded_model = tf.keras.models.load_model(tmpdir, compile=False)

    loutputs = loaded_model(batch)
    # Must use the *reloaded* model; calling ``model`` here (as before) made
    # the padded round-trip assertion vacuous.
    loutputs_pad = loaded_model(batch_pad)

    assert np.all(np.isclose(outputs, loutputs, atol=1E-4, rtol=1E-3))
    assert np.all(np.isclose(output_pad, loutputs_pad, atol=1E-4, rtol=1E-3))
def build_fn(atom_features: int = 64, message_steps: int = 8,
             output_layers: List[int] = (512, 256, 128)):
    """Build a graph network that maps a molecule to a single scalar output.

    Args:
        atom_features: Width of the atom and bond embedding vectors.
        message_steps: Number of residual edge -> node -> global update rounds.
        output_layers: Widths of the ReLU ``Dense`` layers applied to the final
            global state before the scalar head.

    Returns:
        A ``tf.keras.Model`` with inputs ``[atom, bond, connectivity]`` and a
        single output produced by ``Dense(1)`` followed by a linear ``Dense(1)``
        layer named ``'scale'``.
    """
    atom_input = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond_input = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2], dtype=tf.int64,
                                name='connectivity')

    # Map each integer atom class (36 classes) and bond class (5 classes) to a
    # learned vector of width ``atom_features``; class 0 is masked.
    atom_state = layers.Embedding(36, atom_features, name='atom_embedding',
                                  mask_zero=True)(atom_input)
    bond_state = layers.Embedding(5, atom_features, name='bond_embedding',
                                  mask_zero=True)(bond_input)

    # First nfp layer: attention-style reduction of atom/bond states into one
    # graph-level vector.
    # NOTE(review): an earlier comment said num_heads * units must equal the
    # atom/bond width, but 4 * 1 != atom_features (64 by default) here --
    # confirm whether that constraint actually applies.
    global_state = nfp.GlobalUpdate(units=4, num_heads=1, name='problem')(
        [atom_state, bond_state, connectivity])

    for _ in range(message_steps):
        edge_delta = nfp.EdgeUpdate()(
            [atom_state, bond_state, connectivity, global_state])
        bond_state = layers.Add()([bond_state, edge_delta])

        node_delta = nfp.NodeUpdate()(
            [atom_state, bond_state, connectivity, global_state])
        atom_state = layers.Add()([atom_state, node_delta])

        global_delta = nfp.GlobalUpdate(units=4, num_heads=1)(
            [atom_state, bond_state, connectivity, global_state])
        global_state = layers.Add()([global_state, global_delta])

    # Readout head: ReLU stack, then a scalar projection and a scale layer.
    head = global_state
    for width in output_layers:
        head = layers.Dense(width, activation='relu')(head)
    head = layers.Dense(1)(head)
    head = layers.Dense(1, activation='linear', name='scale')(head)

    return tf.keras.Model([atom_input, bond_input, connectivity], [head])
def message_block(atom_state, bond_state, connectivity, global_state, i):
    """Apply one residual round of edge, node and global updates.

    Each nfp update produces a delta that is added back onto the previous
    state with ``layers.Add``.

    ``head_features`` and ``num_heads`` are not parameters. They are looked up
    as free names when this runs and must be defined in an enclosing scope.
    ``i`` is accepted but not used.

    Returns:
        ``(atom_state, bond_state, global_state)`` after the update.
    """
    edge_delta = nfp.EdgeUpdate()([atom_state, bond_state, connectivity])
    updated_bonds = layers.Add()([bond_state, edge_delta])

    # Node update sees the bond states produced just above.
    node_delta = nfp.NodeUpdate()([atom_state, updated_bonds, connectivity])
    updated_atoms = layers.Add()([atom_state, node_delta])

    global_delta = nfp.GlobalUpdate(head_features, num_heads)(
        [updated_atoms, updated_bonds, connectivity])
    updated_global = layers.Add()([global_delta, global_state])

    return updated_atoms, updated_bonds, updated_global
def policy_model():
    """Build the policy/value network over a molecular graph.

    Reads ``preprocessor`` (``atom_classes``, ``bond_classes``) and ``config``
    (``features``, ``num_heads``, ``num_messages``) from the surrounding
    scope; neither is defined in this function.

    Message passing here is non-residual: each nfp update replaces the
    previous state rather than being added to it.

    Returns:
        A ``tf.keras.Model`` named ``'policy_model'`` with inputs
        ``[atom, bond, connectivity]`` and outputs ``[value_logit, pi_logit]``,
        each a ``Dense(1)`` projection of the final global state.
    """
    def embed(num_classes, layer_name, class_ids):
        # Integer class ids -> learned vectors of width config.features;
        # id 0 is masked.
        return layers.Embedding(num_classes, config.features,
                                name=layer_name, mask_zero=True)(class_ids)

    # batch_size, num_atoms
    atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    # batch_size, num_bonds
    bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    # batch_size, num_bonds, 2
    connectivity = layers.Input(shape=[None, 2], dtype=tf.int64,
                                name='connectivity')
    graph_inputs = [atom_class, bond_class, connectivity]

    atom_state = embed(preprocessor.atom_classes, 'atom_embedding', atom_class)
    bond_state = embed(preprocessor.bond_classes, 'bond_embedding', bond_class)

    head_units = config.features // config.num_heads
    global_state = nfp.GlobalUpdate(units=head_units,
                                    num_heads=config.num_heads)(
        [atom_state, bond_state, connectivity])

    for _ in range(config.num_messages):
        bond_state = nfp.EdgeUpdate()(
            [atom_state, bond_state, connectivity, global_state])
        atom_state = nfp.NodeUpdate()(
            [atom_state, bond_state, connectivity, global_state])
        global_state = nfp.GlobalUpdate(units=head_units,
                                        num_heads=config.num_heads)(
            [atom_state, bond_state, connectivity, global_state])

    value_logit = layers.Dense(1)(global_state)
    pi_logit = layers.Dense(1)(global_state)
    return tf.keras.Model(graph_inputs, [value_logit, pi_logit],
                          name='policy_model')
def test_no_residual(smiles_inputs):
    """Check that a non-residual message-passing model is padding-invariant.

    Each nfp update replaces the previous state (no ``Add``). The output for
    an unpadded batch must match the output for the same molecules padded to
    ``max_atoms=20, max_bonds=40``.
    """
    preprocessor, inputs = smiles_inputs
    test_smiles = ['CC', 'CCC', 'C(C)C', 'C']

    def get_inputs(max_atoms=-1, max_bonds=-1):
        """Return the first padded batch of ``test_smiles``.

        ``max_atoms`` / ``max_bonds`` are forwarded to
        ``preprocessor.padded_shapes``.
        """
        def feature_stream():
            for smi in test_smiles:
                yield preprocessor.construct_feature_matrices(smi, train=True)

        raw = tf.data.Dataset.from_generator(
            feature_stream,
            output_types=preprocessor.output_types,
            output_shapes=preprocessor.output_shapes)
        batched = raw.padded_batch(
            batch_size=4,
            padded_shapes=preprocessor.padded_shapes(max_atoms, max_bonds),
            padding_values=preprocessor.padding_values)
        return next(iter(batched.take(1)))

    atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2], dtype=tf.int64,
                                name='connectivity')

    atom_state = layers.Embedding(preprocessor.atom_classes, 16,
                                  mask_zero=True)(atom_class)
    bond_state = layers.Embedding(preprocessor.bond_classes, 16,
                                  mask_zero=True)(bond_class)
    global_state = nfp.GlobalUpdate(8, 2)(
        [atom_state, bond_state, connectivity])

    for _ in range(3):
        bond_state = nfp.EdgeUpdate()([atom_state, bond_state, connectivity])
        atom_state = nfp.NodeUpdate()([atom_state, bond_state, connectivity])
        global_state = nfp.GlobalUpdate(8, 2)(
            [atom_state, bond_state, connectivity])

    model = tf.keras.Model([atom_class, bond_class, connectivity],
                           [global_state])

    unpadded = model(get_inputs())
    padded = model(get_inputs(max_atoms=20, max_bonds=40))
    assert np.allclose(unpadded, padded, atol=1E-4)
mask_zero=True)(atom) # Ditto with the bond state bond_state = layers.Embedding(preprocessor.bond_classes, num_features, name='bond_embedding', mask_zero=True)(bond) # Here we use our first nfp layer. This is an attention layer that looks at # the atom and bond states and reduces them to a single, graph-level vector. # mum_heads * units has to be the same dimension as the atom / bond dimension global_state = nfp.GlobalUpdate( units=units, num_heads=heads)([atom_state, bond_state, connectivity]) for _ in range(3): # Do the message passing new_bond_state = nfp.EdgeUpdate()( [atom_state, bond_state, connectivity, global_state]) bond_state = layers.Add()([bond_state, new_bond_state]) new_atom_state = nfp.NodeUpdate()( [atom_state, bond_state, connectivity, global_state]) atom_state = layers.Add()([atom_state, new_atom_state]) new_global_state = nfp.GlobalUpdate(units=units, num_heads=heads)( [atom_state, bond_state, connectivity, global_state]) global_state = layers.Add()([global_state, new_global_state]) # Since the final prediction is a single, molecule-level property (YSI), we # reduce the last global state to a single prediction. fp_out = layers.Dense(fp_size)(global_state) param_prediction = layers.Dense(1)(global_state)