Exemplo n.º 1
0
def message_block(original_atom_state, original_bond_state, connectivity, i):
    """Build one message-passing step (edge update, then atom update) from Keras layers.

    Every layer is instantiated inside this function. Each call therefore
    creates a fresh set of layers and weights; nothing is shared between calls.

    Args:
        original_atom_state: per-atom feature tensor. It is gathered per bond
            and is also the residual input of the returned atom state.
        original_bond_state: per-bond feature tensor. It is the residual input
            of the returned bond state.
        connectivity: integer index tensor.
            - Column 1 (``np.s_[:, :, 1]``) selects the atoms named
              ``source_atom``.
            - Column 0 selects the atoms named ``target_atom``; messages are
              also summed onto column 0.
            NOTE(review): presumably shaped (batch, num_bonds, 2); confirm with
            the caller.
        i: used only to name the Concatenate layer ``'concat_{i}'``.
            Presumably this keeps layer names unique when the block is stacked;
            confirm.

    Returns:
        Tuple ``(atom_state, bond_state)`` of residually updated atom and bond
        states.

    Reads ``embed_dimension``, which is not defined in this function. It must
    exist in an enclosing or module scope.
    """

    # Plain aliases: unlike a pre-normalized variant, no normalization is
    # applied to the inputs before they are used below.
    atom_state = original_atom_state
    bond_state = original_bond_state

    # Gather per-bond copies of the atom features at each end of every bond.
    source_atom = nfp.Gather()(
        [atom_state, nfp.Slice(np.s_[:, :, 1])(connectivity)])
    target_atom = nfp.Gather()(
        [atom_state, nfp.Slice(np.s_[:, :, 0])(connectivity)])

    # Edge update network
    # Concatenate both endpoint features with the current bond features, then
    # run a 2-layer MLP (2*embed_dimension with ReLU, then embed_dimension).
    new_bond_state = layers.Concatenate(name='concat_{}'.format(i))(
        [source_atom, target_atom, bond_state])
    new_bond_state = layers.Dense(2 * embed_dimension,
                                  activation='relu')(new_bond_state)
    new_bond_state = layers.Dense(embed_dimension)(new_bond_state)

    # Residual connection. Add needs matching shapes, so the last dimension of
    # original_bond_state must equal embed_dimension.
    bond_state = layers.Add()([original_bond_state, new_bond_state])

    # message function
    # Project source-atom features, gate them elementwise with the updated bond
    # state, then sum the per-bond messages onto the atoms indexed by
    # connectivity column 0.
    # NOTE(review): atom_state is passed as nfp.Reduce's third input,
    # presumably as the reference for the output's atom dimension. Confirm
    # against the nfp.Reduce docs.
    source_atom = layers.Dense(embed_dimension)(source_atom)
    messages = layers.Multiply()([source_atom, bond_state])
    messages = nfp.Reduce(reduction='sum')(
        [messages,
         nfp.Slice(np.s_[:, :, 0])(connectivity), atom_state])

    # state transition function
    # Two-layer MLP on the aggregated messages.
    messages = layers.Dense(embed_dimension, activation='relu')(messages)
    messages = layers.Dense(embed_dimension)(messages)

    # Residual connection onto the incoming atom state.
    atom_state = layers.Add()([original_atom_state, messages])

    return atom_state, bond_state
Exemplo n.º 2
0
    def message_block(original_atom_state,
                      original_bond_state,
                      connectivity):
        """ Performs the graph-aware updates

        Builds one message-passing step (edge update, then atom update) from
        freshly created Keras layers. Each call adds new weights.

        Args:
            original_atom_state: per-atom features. They are layer-normalized
                before use and are the residual input of the returned atom
                state.
            original_bond_state: per-bond features. They are layer-normalized
                before use and are the residual input of the returned bond
                state.
            connectivity: integer index tensor.
                - Column 1 selects the atoms named ``source_atom``.
                - Column 0 selects ``target_atom``; it is also the index that
                  messages are summed onto.
                NOTE(review): presumably (batch, num_bonds, 2); confirm.

        Returns:
            Tuple ``(atom_state, bond_state)``. The trailing comma on the return
            line is harmless: it is still a 2-tuple.

        Reads ``atom_features``, which is not defined here. It must come from
        the enclosing scope. The same width is used for the bond MLP, so the
        bond features must also be ``atom_features`` wide for the Add layers.
        """

        # Pre-normalize the inputs. The residual Adds below use the
        # un-normalized originals.
        atom_state = layers.LayerNormalization()(original_atom_state)
        bond_state = layers.LayerNormalization()(original_bond_state)

        # Gather per-bond copies of the (normalized) atom features at each
        # bond endpoint.
        source_atom = nfp.Gather()([atom_state, nfp.Slice(np.s_[:, :, 1])(connectivity)])
        target_atom = nfp.Gather()([atom_state, nfp.Slice(np.s_[:, :, 0])(connectivity)])
        
        # Edge update network
        # Concatenate endpoint and bond features, then run a 2-layer MLP
        # (2*atom_features with ReLU, then atom_features).
        new_bond_state = layers.Concatenate()(
            [source_atom, target_atom, bond_state])
        new_bond_state = layers.Dense(
            2*atom_features, activation='relu')(new_bond_state)
        new_bond_state = layers.Dense(atom_features)(new_bond_state)

        # Residual connection onto the un-normalized bond state.
        bond_state = layers.Add()([original_bond_state, new_bond_state])

        # message function
        # Project source-atom features, gate them elementwise with the updated
        # bond state, and sum the messages onto the atoms indexed by column 0.
        # NOTE(review): atom_state as the third Reduce input is presumably the
        # reference for the output's atom dimension; confirm with nfp docs.
        source_atom = layers.Dense(atom_features)(source_atom)    
        messages = layers.Multiply()([source_atom, bond_state])
        messages = nfp.Reduce(reduction='sum')(
            [messages, nfp.Slice(np.s_[:, :, 0])(connectivity), atom_state])

        # state transition function
        # Two-layer MLP on the aggregated messages.
        messages = layers.Dense(atom_features, activation='relu')(messages)
        messages = layers.Dense(atom_features)(messages)

        # Residual connection onto the un-normalized atom state.
        atom_state = layers.Add()([original_atom_state, messages])

        return atom_state, bond_state, 
Exemplo n.º 3
0
    def build(self, input_shape):
        """Create the helper sub-layers after the base Layer build.

        Expected inputs: ``[atom_state, bond_state, connectivity]``, where
        ``bond_state`` has shape ``[batch, num_bonds, bond_features]``.
        """
        super().build(input_shape)

        self.gather = nfp.Gather()
        # Column-1 slice first, then column 0. This matches the original
        # creation and assignment order, which Keras auto-naming depends on.
        self.slice1, self.slice0 = (
            nfp.Slice(np.s_[:, :, column]) for column in (1, 0)
        )
        self.concat = nfp.ConcatDense()
Exemplo n.º 4
0
    def build(self, input_shape):
        """Create the sub-layers used by this layer.

        The dense-layer width comes from the last dimension of
        ``input_shape[1]``, the shape of the second input.
        """
        super().build(input_shape)

        feature_dim = input_shape[1][-1]

        self.gather = nfp.Gather()
        # Column 0 is created and assigned before column 1, as originally.
        self.slice0, self.slice1 = (
            nfp.Slice(np.s_[:, :, column]) for column in (0, 1)
        )

        self.concat = nfp.ConcatDense()
        self.reduce = nfp.Reduce(reduction='sum')

        # Two-layer MLP: widen to twice the feature size, then project back.
        self.dense1 = layers.Dense(2 * feature_dim, activation='relu')
        self.dense2 = layers.Dense(feature_dim)
Exemplo n.º 5
0
def test_slice():
    """nfp.Slice should pull out each column of a connectivity tensor."""
    connectivity = layers.Input(shape=[None, 2],
                                dtype=tf.int64,
                                name='connectivity')

    # Create and apply the column-0 slice, then the column-1 slice.
    col0_out, col1_out = (nfp.Slice(np.s_[:, :, column])(connectivity)
                          for column in (0, 1))

    model = tf.keras.Model([connectivity], [col0_out, col1_out])

    # One batch of five (column-0, column-1) index pairs.
    pairs = np.array([[1, 6], [2, 7], [3, 8], [4, 9], [5, 0]])
    batch = pairs[np.newaxis, :, :]

    results = model(batch)

    for column, result in enumerate(results):
        assert_allclose(result, batch[:, :, column])
Exemplo n.º 6
0
def test_reduce(smiles_inputs, method):
    """Check nfp.Reduce against the numpy reduction of the same name.

    ``method`` names both the nfp reduction and a numpy function looked up with
    ``getattr(np, method)``. Presumably it is supplied by pytest
    parametrization. ``smiles_inputs`` yields ``(preprocessor, inputs)``;
    presumably it is a fixture, confirm.
    """
    preprocessor, inputs = smiles_inputs
    func = getattr(np, method)

    atom_class = layers.Input(shape=[None], dtype=tf.int64, name='atom')
    bond_class = layers.Input(shape=[None], dtype=tf.int64, name='bond')
    connectivity = layers.Input(shape=[None, 2],
                                dtype=tf.int64,
                                name='connectivity')

    atom_embed = layers.Embedding(preprocessor.atom_classes,
                                  16,
                                  mask_zero=True)(atom_class)
    bond_embed = layers.Embedding(preprocessor.bond_classes,
                                  16,
                                  mask_zero=True)(bond_class)

    # Reduce bond embeddings onto the atoms indexed by connectivity column 0.
    reduced = nfp.Reduce(reduction=method)(
        [bond_embed,
         nfp.Slice(np.s_[:, :, 0])(connectivity), atom_embed])

    model = tf.keras.Model([atom_class, bond_class, connectivity],
                           [atom_embed, bond_embed, reduced])

    _, bond_state, atom_reduced = model(
        [inputs['atom'], inputs['bond'], inputs['connectivity']])

    # Atoms with several bonds: reduce over the bond axis (axis 0).
    assert_allclose(atom_reduced[0, 0, :], func(bond_state[0, :4, :], 0))
    assert_allclose(atom_reduced[0, 1, :], func(bond_state[0, 4:8, :], 0))
    # Atoms with a single bond: reduce a one-row slice the same way. The old
    # form passed a stray positional 0, which assert_allclose took as rtol.
    # NOTE(review): bond index 8 is never checked; confirm against the
    # fixture's connectivity.
    assert_allclose(atom_reduced[0, 2, :], func(bond_state[0, 9:10, :], 0))
    assert_allclose(atom_reduced[0, 3, :], func(bond_state[0, 10:11, :], 0))
    assert_allclose(atom_reduced[0, 4, :], func(bond_state[0, 11:12, :], 0))
    assert_allclose(atom_reduced[0, 5, :], func(bond_state[0, 12:13, :], 0))