예제 #1
0
def test_save_and_load_message(smiles_inputs, tmpdir: 'py.path.local'):
    """Check that a message-passing model survives a save/load round trip.

    A small model (embeddings followed by three rounds of edge, node and
    global updates with residual additions) is evaluated on one batch built
    with the default padding arguments and on the same molecules padded to
    20 atoms / 40 bonds; both outputs must agree. The model is then saved
    to ``tmpdir``, reloaded, and the reloaded model must reproduce both the
    default and the padded outputs.

    Args:
        smiles_inputs: fixture yielding a ``(preprocessor, inputs)`` pair;
            only the preprocessor is used here.
        tmpdir: directory the model is saved to and loaded from.
    """
    preprocessor, _ = smiles_inputs

    def get_inputs(max_atoms=-1, max_bonds=-1):
        """Return the first padded batch built from four fixed SMILES."""
        dataset = tf.data.Dataset.from_generator(
            lambda: (preprocessor.construct_feature_matrices(smiles, train=True)
                     for smiles in ['CC', 'CCC', 'C(C)C', 'C']),
            output_types=preprocessor.output_types,
            output_shapes=preprocessor.output_shapes) \
            .padded_batch(batch_size=4,
                          padded_shapes=preprocessor.padded_shapes(max_atoms, max_bonds),
                          padding_values=preprocessor.padding_values)

        return list(dataset.take(1))[0]

    atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2],
                                dtype=tf.int64,
                                name='connectivity')

    atom_state = layers.Embedding(preprocessor.atom_classes,
                                  16,
                                  mask_zero=True)(atom_class)
    bond_state = layers.Embedding(preprocessor.bond_classes,
                                  16,
                                  mask_zero=True)(bond_class)
    global_state = nfp.GlobalUpdate(8,
                                    2)([atom_state, bond_state, connectivity])

    for _ in range(3):
        new_bond_state = nfp.EdgeUpdate()(
            [atom_state, bond_state, connectivity])
        bond_state = layers.Add()([bond_state, new_bond_state])

        new_atom_state = nfp.NodeUpdate()(
            [atom_state, bond_state, connectivity])
        atom_state = layers.Add()([atom_state, new_atom_state])

        new_global_state = nfp.GlobalUpdate(
            8, 2)([atom_state, bond_state, connectivity])
        global_state = layers.Add()([new_global_state, global_state])

    model = tf.keras.Model([atom_class, bond_class, connectivity],
                           [global_state])

    # The output must not depend on how much padding the batch carries.
    outputs = model(get_inputs())
    output_pad = model(get_inputs(max_atoms=20, max_bonds=40))
    assert np.all(np.isclose(outputs, output_pad, atol=1E-4, rtol=1E-4))

    model.save(tmpdir, include_optimizer=False)
    loaded_model = tf.keras.models.load_model(tmpdir, compile=False)
    loutputs = loaded_model(get_inputs())
    # Evaluate the *reloaded* model on the padded batch; using the original
    # model here would make the final assertion compare a model with itself.
    loutputs_pad = loaded_model(get_inputs(max_atoms=20, max_bonds=40))

    assert np.all(np.isclose(outputs, loutputs, atol=1E-4, rtol=1E-3))
    assert np.all(np.isclose(output_pad, loutputs_pad, atol=1E-4, rtol=1E-3))
예제 #2
0
def build_fn(atom_features: int = 64,
             message_steps: int = 8,
             output_layers: List[int] = (512, 256, 128)):
    """Assemble a tf.keras message-passing model with a single scalar output.

    Args:
        atom_features: width of the atom and bond embedding vectors.
        message_steps: number of edge / node / global update rounds.
        output_layers: widths of the ReLU dense layers applied to the final
            global state before the two single-unit output layers.

    Returns:
        A ``tf.keras.Model`` taking ``[atom, bond, connectivity]`` inputs
        and producing one output tensor.
    """

    def _residual(layer, state, features):
        # Apply ``layer`` to ``features`` and add the result onto ``state``.
        update = layer(features)
        return layers.Add()([state, update])

    atom = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2],
                                dtype=tf.int64,
                                name='connectivity')

    # Map each integer atom class to a learned vector of weights.
    atom_embedding = layers.Embedding(36,
                                      atom_features,
                                      name='atom_embedding',
                                      mask_zero=True)
    atom_state = atom_embedding(atom)

    # Same treatment for the integer bond classes.
    bond_embedding = layers.Embedding(5,
                                      atom_features,
                                      name='bond_embedding',
                                      mask_zero=True)
    bond_state = bond_embedding(bond)

    # First nfp layer: an attention layer that looks at the atom and bond
    # states and reduces them to a single, graph-level vector.
    # num_heads * units has to be the same dimension as the atom / bond
    # dimension.
    # NOTE(review): here units * num_heads is 4 while atom_features defaults
    # to 64, and the layer is named 'problem' -- confirm this is intentional.
    initial_global = nfp.GlobalUpdate(units=4, num_heads=1, name='problem')
    global_state = initial_global([atom_state, bond_state, connectivity])

    # Message passing: bonds, then atoms, then the global state, each with a
    # residual connection onto the previous state.
    for _ in range(message_steps):
        bond_state = _residual(
            nfp.EdgeUpdate(), bond_state,
            [atom_state, bond_state, connectivity, global_state])
        atom_state = _residual(
            nfp.NodeUpdate(), atom_state,
            [atom_state, bond_state, connectivity, global_state])
        global_state = _residual(
            nfp.GlobalUpdate(units=4, num_heads=1), global_state,
            [atom_state, bond_state, connectivity, global_state])

    # Dense read-out head applied to the final global state.
    hidden = global_state
    for width in output_layers:
        hidden = layers.Dense(width, activation='relu')(hidden)
    prediction = layers.Dense(1)(hidden)
    scaled = layers.Dense(1, activation='linear', name='scale')(prediction)

    # Construct the tf.keras model
    return tf.keras.Model([atom, bond, connectivity], [scaled])
예제 #3
0
파일: model.py 프로젝트: pstjohn/spin_gnn
    def message_block(atom_state, bond_state, connectivity, global_state, i):
        """Run one round of edge, node and global updates with residual sums.

        ``head_features`` and ``num_heads`` are read from the enclosing
        scope. The ``i`` argument is accepted but not used in the body.

        Returns:
            The updated ``(atom_state, bond_state, global_state)`` tuple.
        """
        # Bonds are updated first, from the incoming atom and bond states.
        bond_update = nfp.EdgeUpdate()([atom_state, bond_state, connectivity])
        next_bonds = layers.Add()([bond_state, bond_update])

        # Atoms are updated using the freshly updated bonds.
        atom_update = nfp.NodeUpdate()([atom_state, next_bonds, connectivity])
        next_atoms = layers.Add()([atom_state, atom_update])

        # The global state is updated from the new atom and bond states.
        global_layer = nfp.GlobalUpdate(head_features, num_heads)
        global_update = global_layer([next_atoms, next_bonds, connectivity])
        next_global = layers.Add()([global_update, global_state])

        return next_atoms, next_bonds, next_global
예제 #4
0
파일: policy.py 프로젝트: dmdu/rlmolecule
def policy_model():
    """Build the 'policy_model' Keras model with value and pi logit outputs.

    Sizes are read from the module-level ``preprocessor`` (class counts) and
    ``config`` (``features``, ``num_heads``, ``num_messages``). Each message
    round replaces the bond, atom and global states outright; no residual
    additions are used.

    Returns:
        A ``tf.keras.Model`` mapping ``[atom, bond, connectivity]`` inputs
        to ``[value_logit, pi_logit]``.
    """
    # Inputs: atom is (batch_size, num_atoms), bond is (batch_size,
    # num_bonds), connectivity is (batch_size, num_bonds, 2).
    atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2],
                                dtype=tf.int64,
                                name='connectivity')

    # Initial atom states from the integer atom classes.
    atom_embedding = layers.Embedding(preprocessor.atom_classes,
                                      config.features,
                                      name='atom_embedding',
                                      mask_zero=True)
    atom_state = atom_embedding(atom_class)

    # Initial bond states from the integer bond classes.
    bond_embedding = layers.Embedding(preprocessor.bond_classes,
                                      config.features,
                                      name='bond_embedding',
                                      mask_zero=True)
    bond_state = bond_embedding(bond_class)

    # Split the feature width evenly across the attention heads.
    head_units = config.features // config.num_heads
    global_state = nfp.GlobalUpdate(
        units=head_units,
        num_heads=config.num_heads)([atom_state, bond_state, connectivity])

    # Message passing: each state is overwritten by its update.
    for _ in range(config.num_messages):
        bond_state = nfp.EdgeUpdate()(
            [atom_state, bond_state, connectivity, global_state])
        atom_state = nfp.NodeUpdate()(
            [atom_state, bond_state, connectivity, global_state])
        global_layer = nfp.GlobalUpdate(units=head_units,
                                        num_heads=config.num_heads)
        global_state = global_layer(
            [atom_state, bond_state, connectivity, global_state])

    # Two independent single-unit heads on the final global state.
    value_logit = layers.Dense(1)(global_state)
    pi_logit = layers.Dense(1)(global_state)

    return tf.keras.Model([atom_class, bond_class, connectivity],
                          [value_logit, pi_logit],
                          name='policy_model')
예제 #5
0
def test_no_residual(smiles_inputs):
    """Check padding invariance of a model built without residual additions.

    The model output for a batch built with the default padding arguments
    must match the output for the same molecules padded to 20 atoms and
    40 bonds.
    """
    preprocessor, _ = smiles_inputs

    def get_inputs(max_atoms=-1, max_bonds=-1):
        """Return the first padded batch built from four fixed SMILES."""

        def feature_stream():
            for smiles in ['CC', 'CCC', 'C(C)C', 'C']:
                yield preprocessor.construct_feature_matrices(smiles,
                                                              train=True)

        unbatched = tf.data.Dataset.from_generator(
            feature_stream,
            output_types=preprocessor.output_types,
            output_shapes=preprocessor.output_shapes)
        batched = unbatched.padded_batch(
            batch_size=4,
            padded_shapes=preprocessor.padded_shapes(max_atoms, max_bonds),
            padding_values=preprocessor.padding_values)

        return next(iter(batched.take(1)))

    atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2],
                                dtype=tf.int64,
                                name='connectivity')

    atom_embedding = layers.Embedding(preprocessor.atom_classes,
                                      16,
                                      mask_zero=True)
    atom_state = atom_embedding(atom_class)
    bond_embedding = layers.Embedding(preprocessor.bond_classes,
                                      16,
                                      mask_zero=True)
    bond_state = bond_embedding(bond_class)
    global_state = nfp.GlobalUpdate(8, 2)(
        [atom_state, bond_state, connectivity])

    # Three update rounds; every state is replaced rather than added to.
    for _ in range(3):
        bond_state = nfp.EdgeUpdate()(
            [atom_state, bond_state, connectivity])
        atom_state = nfp.NodeUpdate()(
            [atom_state, bond_state, connectivity])
        global_state = nfp.GlobalUpdate(8, 2)(
            [atom_state, bond_state, connectivity])

    model = tf.keras.Model([atom_class, bond_class, connectivity],
                           [global_state])

    output = model(get_inputs())
    output_pad = model(get_inputs(max_atoms=20, max_bonds=40))

    assert np.isclose(output, output_pad, atol=1E-4).all()
예제 #6
0
                              mask_zero=True)(atom)

# Embed the integer bond classes into `num_features`-wide vectors;
# mask_zero=True is set on the embedding.
bond_state = layers.Embedding(preprocessor.bond_classes,
                              num_features,
                              name='bond_embedding',
                              mask_zero=True)(bond)

# Here we use our first nfp layer. This is an attention layer that looks at
# the atom and bond states and reduces them to a single, graph-level vector.
# num_heads * units has to be the same dimension as the atom / bond dimension
global_state = nfp.GlobalUpdate(
    units=units, num_heads=heads)([atom_state, bond_state, connectivity])

for _ in range(3):  # Do the message passing
    # Update the bonds and add the update onto the previous bond state.
    new_bond_state = nfp.EdgeUpdate()(
        [atom_state, bond_state, connectivity, global_state])
    bond_state = layers.Add()([bond_state, new_bond_state])

    # Update the atoms (using the new bond state) with a residual addition.
    new_atom_state = nfp.NodeUpdate()(
        [atom_state, bond_state, connectivity, global_state])
    atom_state = layers.Add()([atom_state, new_atom_state])

    # Update the global state from the new atom and bond states, again
    # adding the update onto the previous global state.
    new_global_state = nfp.GlobalUpdate(units=units, num_heads=heads)(
        [atom_state, bond_state, connectivity, global_state])
    global_state = layers.Add()([global_state, new_global_state])

# Since the final prediction is a single, molecule-level property (YSI), we
# reduce the last global state to a single prediction. Two dense heads are
# applied to the same global state: one of width `fp_size` and one of width 1.
fp_out = layers.Dense(fp_size)(global_state)
param_prediction = layers.Dense(1)(global_state)