Example #1
# Imports shared by the snippets on this page
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

import nfp


def test_global(smiles_inputs, dropout):
    preprocessor, inputs = smiles_inputs

    atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2],
                                dtype=tf.int64,
                                name='connectivity')

    atom_state = layers.Embedding(preprocessor.atom_classes,
                                  16,
                                  mask_zero=True)(atom_class)
    bond_state = layers.Embedding(preprocessor.bond_classes,
                                  16,
                                  mask_zero=True)(bond_class)
    global_state = layers.GlobalAveragePooling1D()(atom_state)

    update = nfp.GlobalUpdate(
        8, 2, dropout=dropout)([atom_state, bond_state, connectivity])
    update_global = nfp.GlobalUpdate(8, 2, dropout=dropout)(
        [atom_state, bond_state, connectivity, global_state])

    model = tf.keras.Model([atom_class, bond_class, connectivity],
                           [update, update_global])

    update_state, update_state_global = model(inputs)

    assert not hasattr(update, '_keras_mask')
    assert not hasattr(update_global, '_keras_mask')
    assert update_state.shape == update_state_global.shape
    assert not np.all(update_state == update_state_global)
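
The smiles_inputs fixture (and the dropout parametrization) used by these tests is not shown. A minimal sketch of what it could look like, reusing the dataset construction from the snippets below; the SmilesPreprocessor import path and constructor defaults are assumptions and depend on the installed nfp version:

import pytest
import tensorflow as tf
from nfp.preprocessing import SmilesPreprocessor  # assumed import path


@pytest.fixture
def smiles_inputs():
    # Featurize a few small molecules; train=True lets the preprocessor
    # grow its atom / bond class vocabularies as it sees new features
    preprocessor = SmilesPreprocessor()
    dataset = tf.data.Dataset.from_generator(
        lambda: (preprocessor.construct_feature_matrices(smiles, train=True)
                 for smiles in ['CC', 'CCC', 'C(C)C', 'C']),
        output_types=preprocessor.output_types,
        output_shapes=preprocessor.output_shapes) \
        .padded_batch(batch_size=4,
                      padded_shapes=preprocessor.padded_shapes(-1, -1),
                      padding_values=preprocessor.padding_values)

    # A single padded batch of (atom, bond, connectivity) model inputs
    inputs = list(dataset.take(1))[0]
    return preprocessor, inputs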
Example #2
def test_save_and_load_message(smiles_inputs, tmpdir: 'py.path.local'):
    preprocessor, inputs = smiles_inputs

    def get_inputs(max_atoms=-1, max_bonds=-1):
        dataset = tf.data.Dataset.from_generator(
            lambda: (preprocessor.construct_feature_matrices(smiles, train=True)
                     for smiles in ['CC', 'CCC', 'C(C)C', 'C']),
            output_types=preprocessor.output_types,
            output_shapes=preprocessor.output_shapes) \
            .padded_batch(batch_size=4,
                          padded_shapes=preprocessor.padded_shapes(max_atoms, max_bonds),
                          padding_values=preprocessor.padding_values)

        return list(dataset.take(1))[0]

    atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2],
                                dtype=tf.int64,
                                name='connectivity')

    atom_state = layers.Embedding(preprocessor.atom_classes,
                                  16,
                                  mask_zero=True)(atom_class)
    bond_state = layers.Embedding(preprocessor.bond_classes,
                                  16,
                                  mask_zero=True)(bond_class)
    global_state = nfp.GlobalUpdate(8,
                                    2)([atom_state, bond_state, connectivity])

    for _ in range(3):
        new_bond_state = nfp.EdgeUpdate()(
            [atom_state, bond_state, connectivity])
        bond_state = layers.Add()([bond_state, new_bond_state])

        new_atom_state = nfp.NodeUpdate()(
            [atom_state, bond_state, connectivity])
        atom_state = layers.Add()([atom_state, new_atom_state])

        new_global_state = nfp.GlobalUpdate(
            8, 2)([atom_state, bond_state, connectivity])
        global_state = layers.Add()([new_global_state, global_state])

    model = tf.keras.Model([atom_class, bond_class, connectivity],
                           [global_state])
    outputs = model(get_inputs())
    output_pad = model(get_inputs(max_atoms=20, max_bonds=40))
    assert np.all(np.isclose(outputs, output_pad, atol=1E-4, rtol=1E-4))

    model.save(tmpdir, include_optimizer=False)
    loaded_model = tf.keras.models.load_model(tmpdir, compile=False)
    loutputs = loaded_model(get_inputs())
    loutputs_pad = loaded_model(get_inputs(max_atoms=20, max_bonds=40))

    assert np.all(np.isclose(outputs, loutputs, atol=1E-4, rtol=1E-3))
    assert np.all(np.isclose(output_pad, loutputs_pad, atol=1E-4, rtol=1E-3))
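
The round trip above uses the TensorFlow SavedModel format (a directory), which stores the custom nfp layers alongside the graph. A hedged sketch of the HDF5 alternative, assuming the installed nfp version exposes a custom_objects mapping of its layer classes (verify against your version; the file name is illustrative):

h5_path = str(tmpdir.join('model.h5'))  # hypothetical file name
model.save(h5_path, include_optimizer=False)

# Rebuilding from H5 needs the custom layer classes passed explicitly
loaded_h5 = tf.keras.models.load_model(
    h5_path, custom_objects=nfp.custom_objects, compile=False)

assert np.all(np.isclose(outputs, loaded_h5(get_inputs()), atol=1E-4, rtol=1E-3))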
Example #3
def build_fn(atom_features: int = 64,
             message_steps: int = 8,
             output_layers: List[int] = (512, 256, 128)):
    atom = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2],
                                dtype=tf.int64,
                                name='connectivity')

    # Convert from a single integer defining the atom state to a vector
    # of weights associated with that class
    atom_state = layers.Embedding(36,
                                  atom_features,
                                  name='atom_embedding',
                                  mask_zero=True)(atom)

    # Ditto with the bond state
    bond_state = layers.Embedding(5,
                                  atom_features,
                                  name='bond_embedding',
                                  mask_zero=True)(bond)

    # Here we use our first nfp layer. This is an attention layer that looks at
    # the atom and bond states and reduces them to a single, graph-level vector.
    # num_heads * units has to be the same dimension as the atom / bond dimension
    global_state = nfp.GlobalUpdate(units=4, num_heads=1, name='problem')(
        [atom_state, bond_state, connectivity])

    for _ in range(message_steps):  # Do the message passing
        new_bond_state = nfp.EdgeUpdate()(
            [atom_state, bond_state, connectivity, global_state])
        bond_state = layers.Add()([bond_state, new_bond_state])

        new_atom_state = nfp.NodeUpdate()(
            [atom_state, bond_state, connectivity, global_state])
        atom_state = layers.Add()([atom_state, new_atom_state])

        new_global_state = nfp.GlobalUpdate(units=4, num_heads=1)(
            [atom_state, bond_state, connectivity, global_state])
        global_state = layers.Add()([global_state, new_global_state])

    # Pass the global state through the fully connected output layers
    output = global_state
    for shape in output_layers:
        output = layers.Dense(shape, activation='relu')(output)
    output = layers.Dense(1)(output)
    # A final single-unit linear layer, named 'scale', rescales the prediction
    output = layers.Dense(1, activation='linear', name='scale')(output)

    # Construct the tf.keras model
    return tf.keras.Model([atom, bond, connectivity], [output])
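
A short usage sketch for build_fn; the optimizer, learning rate, and loss below are illustrative choices rather than part of the snippet:

model = build_fn(atom_features=64, message_steps=8)
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss='mae')
model.summary()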
Example #4
def policy_model():

    # Define inputs
    atom_class = layers.Input(shape=[None], dtype=tf.int64,
                              name='atom')  # batch_size, num_atoms
    bond_class = layers.Input(shape=[None], dtype=tf.int64,
                              name='bond')  # batch_size, num_bonds
    connectivity = layers.Input(
        shape=[None, 2], dtype=tf.int64,
        name='connectivity')  # batch_size, num_bonds, 2

    input_tensors = [atom_class, bond_class, connectivity]

    # Initialize the atom states
    atom_state = layers.Embedding(preprocessor.atom_classes,
                                  config.features,
                                  name='atom_embedding',
                                  mask_zero=True)(atom_class)

    # Initialize the bond states
    bond_state = layers.Embedding(preprocessor.bond_classes,
                                  config.features,
                                  name='bond_embedding',
                                  mask_zero=True)(bond_class)

    units = config.features // config.num_heads
    global_state = nfp.GlobalUpdate(units=units, num_heads=config.num_heads)(
        [atom_state, bond_state, connectivity])

    for _ in range(config.num_messages):  # Do the message passing
        bond_state = nfp.EdgeUpdate()(
            [atom_state, bond_state, connectivity, global_state])
        atom_state = nfp.NodeUpdate()(
            [atom_state, bond_state, connectivity, global_state])
        global_state = nfp.GlobalUpdate(units=units,
                                        num_heads=config.num_heads)([
                                            atom_state, bond_state,
                                            connectivity, global_state
                                        ])

    value_logit = layers.Dense(1)(global_state)
    pi_logit = layers.Dense(1)(global_state)

    return tf.keras.Model(input_tensors, [value_logit, pi_logit],
                          name='policy_model')
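
policy_model depends on module-level preprocessor and config objects that are not shown; the preprocessor is assumed to be a fitted nfp preprocessor as in the earlier examples. A minimal sketch of the assumed config (attribute names taken from how they are used above; the values are placeholders), together with the expected output shapes:

from types import SimpleNamespace

# Placeholder values; features must be divisible by num_heads so that
# units * num_heads matches the atom / bond feature dimension
config = SimpleNamespace(features=64, num_heads=4, num_messages=6)

model = policy_model()
# Given a padded batch of (atom, bond, connectivity) inputs as built above:
# value_logit, pi_logit = model(inputs)   # each of shape [batch_size, 1]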
Example #5
def test_no_residual(smiles_inputs):
    preprocessor, inputs = smiles_inputs

    def get_inputs(max_atoms=-1, max_bonds=-1):
        dataset = tf.data.Dataset.from_generator(
            lambda: (preprocessor.construct_feature_matrices(smiles, train=True)
                     for smiles in ['CC', 'CCC', 'C(C)C', 'C']),
            output_types=preprocessor.output_types,
            output_shapes=preprocessor.output_shapes) \
            .padded_batch(batch_size=4,
                          padded_shapes=preprocessor.padded_shapes(max_atoms, max_bonds),
                          padding_values=preprocessor.padding_values)

        return list(dataset.take(1))[0]

    atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2],
                                dtype=tf.int64,
                                name='connectivity')

    atom_state = layers.Embedding(preprocessor.atom_classes,
                                  16,
                                  mask_zero=True)(atom_class)
    bond_state = layers.Embedding(preprocessor.bond_classes,
                                  16,
                                  mask_zero=True)(bond_class)
    global_state = nfp.GlobalUpdate(8,
                                    2)([atom_state, bond_state, connectivity])

    for _ in range(3):
        bond_state = nfp.EdgeUpdate()([atom_state, bond_state, connectivity])
        atom_state = nfp.NodeUpdate()([atom_state, bond_state, connectivity])
        global_state = nfp.GlobalUpdate(
            8, 2)([atom_state, bond_state, connectivity])

    model = tf.keras.Model([atom_class, bond_class, connectivity],
                           [global_state])

    output = model(get_inputs())
    output_pad = model(get_inputs(max_atoms=20, max_bonds=40))

    assert np.all(np.isclose(output, output_pad, atol=1E-4))
Example #6
    def message_block(atom_state, bond_state, connectivity, global_state, i):

        new_bond_state = nfp.EdgeUpdate()(
            [atom_state, bond_state, connectivity])
        bond_state = layers.Add()([bond_state, new_bond_state])

        new_atom_state = nfp.NodeUpdate()(
            [atom_state, bond_state, connectivity])
        atom_state = layers.Add()([atom_state, new_atom_state])

        new_global_state = nfp.GlobalUpdate(
            head_features, num_heads)([atom_state, bond_state, connectivity])
        global_state = layers.Add()([new_global_state, global_state])

        return atom_state, bond_state, global_state
Example #7
# Convert from a single integer defining the atom state to a vector
# of weights associated with that class
atom_state = layers.Embedding(preprocessor.atom_classes,
                              num_features,
                              name='atom_embedding',
                              mask_zero=True)(atom)

# Ditto with the bond state
bond_state = layers.Embedding(preprocessor.bond_classes,
                              num_features,
                              name='bond_embedding',
                              mask_zero=True)(bond)

# Here we use our first nfp layer. This is an attention layer that looks at
# the atom and bond states and reduces them to a single, graph-level vector.
# num_heads * units has to be the same dimension as the atom / bond dimension
global_state = nfp.GlobalUpdate(
    units=units, num_heads=heads)([atom_state, bond_state, connectivity])

for _ in range(3):  # Do the message passing
    new_bond_state = nfp.EdgeUpdate()(
        [atom_state, bond_state, connectivity, global_state])
    bond_state = layers.Add()([bond_state, new_bond_state])

    new_atom_state = nfp.NodeUpdate()(
        [atom_state, bond_state, connectivity, global_state])
    atom_state = layers.Add()([atom_state, new_atom_state])

    new_global_state = nfp.GlobalUpdate(units=units, num_heads=heads)(
        [atom_state, bond_state, connectivity, global_state])
    global_state = layers.Add()([global_state, new_global_state])

# Since the final prediction is a single, molecule-level property (YSI), we
# reduce the final global state to a single output with a dense layer
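
A sketch of the output head implied by the comment above: the graph-level global state is reduced to a single YSI prediction. The layer name is illustrative, and the input tensors (atom, bond, connectivity) are defined above this excerpt, which is truncated here:

ysi_prediction = layers.Dense(1, name='ysi')(global_state)

model = tf.keras.Model([atom, bond, connectivity], [ysi_prediction])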
Example #8
def build_embedding_model(preprocessor,
                          dropout=0.0,
                          atom_features=128,
                          num_messages=6,
                          num_heads=8,
                          name='atom_embedding_model'):

    assert atom_features % num_heads == 0, "Wrong feature / head dimension"
    head_features = atom_features // num_heads

    # Define keras model
    n_atom = layers.Input(shape=[], dtype=tf.int64, name='n_atom')
    atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2],
                                dtype=tf.int64,
                                name='connectivity')

    input_tensors = [atom_class, bond_class, connectivity, n_atom]

    # Initialize the atom states
    atom_state = layers.Embedding(preprocessor.atom_classes,
                                  atom_features,
                                  name='atom_embedding',
                                  mask_zero=True)(atom_class)

    # Initialize the bond states
    bond_state = layers.Embedding(preprocessor.bond_classes,
                                  atom_features,
                                  name='bond_embedding',
                                  mask_zero=True)(bond_class)

    global_state = nfp.GlobalUpdate(
        head_features, num_heads)([atom_state, bond_state, connectivity])

    def message_block(atom_state, bond_state, connectivity, global_state, i):

        new_bond_state = nfp.EdgeUpdate()(
            [atom_state, bond_state, connectivity])
        bond_state = layers.Add()([bond_state, new_bond_state])

        new_atom_state = nfp.NodeUpdate()(
            [atom_state, bond_state, connectivity])
        atom_state = layers.Add()([atom_state, new_atom_state])

        new_global_state = nfp.GlobalUpdate(
            head_features, num_heads)([atom_state, bond_state, connectivity])
        global_state = layers.Add()([new_global_state, global_state])

        return atom_state, bond_state, global_state

    atom_states = [atom_state]
    bond_states = [bond_state]
    global_states = [global_state]

    for i in range(num_messages):
        atom_state, bond_state, global_state = message_block(
            atom_state, bond_state, connectivity, global_state, i)

        atom_states += [atom_state]
        bond_states += [bond_state]
        global_states += [global_state]

#    atom_embedding_model = tf.keras.Model(input_tensors, [atom_state, bond_state, global_state], name=name)

    return input_tensors, atom_states, bond_states, global_states
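
The commented-out line above hints at how the returned tensors are meant to be assembled. A minimal sketch that builds the embedding model from the final message-passing step (exposing every intermediate state instead is an equally valid choice):

# preprocessor: a fitted nfp preprocessor, as in the earlier examples
input_tensors, atom_states, bond_states, global_states = \
    build_embedding_model(preprocessor)

# n_atom is carried as an input even though these outputs do not use it
embedding_model = tf.keras.Model(
    input_tensors,
    [atom_states[-1], bond_states[-1], global_states[-1]],
    name='atom_embedding_model')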