Example #1
0
def message_block(original_atom_state, original_bond_state, connectivity, i):
    """Build one residual edge-update + message-passing step (Keras functional API).

    Args:
        original_atom_state: atom feature tensor; residual base for the
            returned atom state.
        original_bond_state: bond feature tensor; residual base for the
            returned bond state (assumed to have last dim ``embed_dimension``
            so the residual ``Add`` lines up -- TODO confirm).
        connectivity: integer index tensor sliced as ``[:, :, 1]`` (source
            atoms) and ``[:, :, 0]`` (target atoms); presumably shaped
            (batch, num_bonds, 2) -- TODO confirm.
        i: block index, used only to name the concatenation layer
            ``'concat_<i>'``.

    Returns:
        Tuple ``(atom_state, bond_state)`` of residually-updated tensors.

    Note:
        Every Dense width comes from the module-level ``embed_dimension``.
        Layers are created in a fixed order; keep it, since Keras derives
        auto-generated layer names from creation order.
    """
    # Per-bond endpoint features: column 1 indexes the source atom,
    # column 0 the target atom.
    gather_source = nfp.Gather()
    source_index = nfp.Slice(np.s_[:, :, 1])(connectivity)
    sender = gather_source([original_atom_state, source_index])

    gather_target = nfp.Gather()
    target_index = nfp.Slice(np.s_[:, :, 0])(connectivity)
    receiver = gather_target([original_atom_state, target_index])

    # Edge update: MLP over [source, target, bond], added back residually.
    edge_input = layers.Concatenate(name='concat_{}'.format(i))(
        [sender, receiver, original_bond_state])
    hidden = layers.Dense(2 * embed_dimension, activation='relu')(edge_input)
    edge_delta = layers.Dense(embed_dimension)(hidden)
    updated_bonds = layers.Add()([original_bond_state, edge_delta])

    # Message: projected source features multiplied elementwise by the
    # updated bond state, then 'sum'-reduced against the target (column 0)
    # indices (presumably a scatter-sum shaped like the atom state -- confirm
    # nfp.Reduce semantics).
    projected_sender = layers.Dense(embed_dimension)(sender)
    per_bond_message = layers.Multiply()([projected_sender, updated_bonds])
    reducer = nfp.Reduce(reduction='sum')
    reduce_index = nfp.Slice(np.s_[:, :, 0])(connectivity)
    aggregated = reducer(
        [per_bond_message, reduce_index, original_atom_state])

    # State transition MLP, added back residually onto the atom features.
    update = layers.Dense(embed_dimension, activation='relu')(aggregated)
    update = layers.Dense(embed_dimension)(update)
    new_atoms = layers.Add()([original_atom_state, update])

    return new_atoms, updated_bonds
Example #2
0
    def message_block(original_atom_state,
                      original_bond_state,
                      connectivity):
        """Run one pre-normalized edge-update + message-passing step.

        Both input states are layer-normalized first. The residual additions
        at the end are applied to the *un-normalized* originals.

        Args:
            original_atom_state: atom feature tensor (residual base for the
                returned atom state).
            original_bond_state: bond feature tensor (residual base for the
                returned bond state).
            connectivity: integer index tensor sliced as ``[:, :, 1]`` for
                source atoms and ``[:, :, 0]`` for target atoms; presumably
                (batch, num_bonds, 2) -- TODO confirm.

        Returns:
            Tuple ``(atom_state, bond_state)``.

        Note:
            Every Dense width comes from ``atom_features`` in the enclosing
            scope. Layer creation order is kept stable because Keras derives
            auto-generated layer names from it.
        """
        normed_atoms = layers.LayerNormalization()(original_atom_state)
        normed_bonds = layers.LayerNormalization()(original_bond_state)

        def gather_endpoint(column):
            # Normalized features of the atom indexed by `column` of
            # connectivity.
            gather = nfp.Gather()
            index = nfp.Slice(np.s_[:, :, column])(connectivity)
            return gather([normed_atoms, index])

        sender = gather_endpoint(1)
        receiver = gather_endpoint(0)

        # Edge update: MLP over the concatenated endpoint and bond features.
        edge_features = layers.Concatenate()(
            [sender, receiver, normed_bonds])
        edge_features = layers.Dense(
            2 * atom_features, activation='relu')(edge_features)
        edge_delta = layers.Dense(atom_features)(edge_features)
        new_bonds = layers.Add()([original_bond_state, edge_delta])

        # Message: projected sender features times the updated bond state,
        # 'sum'-reduced against the target (column 0) indices (presumably a
        # scatter-sum shaped like the atom state -- confirm nfp.Reduce).
        projected = layers.Dense(atom_features)(sender)
        per_bond = layers.Multiply()([projected, new_bonds])
        reduce_layer = nfp.Reduce(reduction='sum')
        target_index = nfp.Slice(np.s_[:, :, 0])(connectivity)
        incoming = reduce_layer([per_bond, target_index, normed_atoms])

        # State transition MLP, then residual add onto the original atoms.
        incoming = layers.Dense(atom_features, activation='relu')(incoming)
        incoming = layers.Dense(atom_features)(incoming)
        new_atoms = layers.Add()([original_atom_state, incoming])

        return new_atoms, new_bonds
Example #3
0
    def build(self, input_shape):
        """Instantiate the sub-layers (presumably used by ``call``, not shown).

        Expected inputs are ``[atom_state, bond_state, connectivity]``, with
        bond_state shaped ``[batch, num_bonds, bond_features]``.
        """
        super().build(input_shape)

        self.gather = nfp.Gather()
        # Column selectors over the last axis of connectivity. The tuple is
        # evaluated left to right, so slice1 is still created before slice0.
        self.slice1, self.slice0 = (nfp.Slice(np.s_[:, :, 1]),
                                    nfp.Slice(np.s_[:, :, 0]))
        self.concat = nfp.ConcatDense()
Example #4
0
def test_gather():
    """Check that nfp.Gather picks, per batch row, data entries at the given indices."""
    # Declare the float dtype explicitly so it matches the float32 data below.
    in1 = layers.Input(shape=[None], dtype='float32', name='data')
    in2 = layers.Input(shape=[None], dtype=tf.int64, name='indices')

    gather = nfp.Gather()([in1, in2])

    model = tf.keras.Model([in1, in2], [gather])

    data = np.random.rand(2, 10).astype(np.float32)
    # Pin the index dtype to the int64 declared on the Input: np.array's
    # default integer type is platform-dependent (int32 on Windows before
    # NumPy 2.0), which would feed a mismatched dtype into the model.
    indices = np.array([[2, 6, 3], [5, 1, 0]], dtype=np.int64)
    out = model([data, indices])

    # Row-wise reference: expected[b, k] == data[b, indices[b, k]].
    expected = np.take_along_axis(data, indices, axis=1)
    assert_allclose(out, expected)
Example #5
0
    def build(self, input_shape):
        """Instantiate sub-layers sized from the second input's feature width.

        ``input_shape[1][-1]`` is the last dimension of the second input
        (presumably the bond-state input -- confirm). It sets the output width
        of ``dense2``; ``dense1`` uses twice that width.
        """
        super().build(input_shape)

        feature_width = input_shape[1][-1]

        self.gather = nfp.Gather()
        # Evaluated left to right, so slice0 is created before slice1.
        self.slice0, self.slice1 = (nfp.Slice(np.s_[:, :, 0]),
                                    nfp.Slice(np.s_[:, :, 1]))

        self.concat = nfp.ConcatDense()
        self.reduce = nfp.Reduce(reduction='sum')

        # Two-layer MLP: expand to 2x the feature width, then project back.
        hidden_width = 2 * feature_width
        self.dense1 = layers.Dense(hidden_width, activation='relu')
        self.dense2 = layers.Dense(feature_width)