Code example #1
File: modeling.py Project: greck2908/language
    def _upsample_molecules_to_chars(self, final_char_input_seq: tf.Tensor,
                                     full_molecules: tf.Tensor) -> tf.Tensor:
        """Run a shallow/low-dim transformer to get a final character encoding."""
        _, char_seq_length, _ = bert_modeling.get_shape_list(
            final_char_input_seq)

        # `repeated_molecules`: [batch_size, char_seq_len, molecule_hidden_size]
        repeated_molecules = self._repeat_molecules(
            full_molecules, char_seq_length=char_seq_length)
        # `concat`:
        #     [batch_size, char_seq_len, molecule_hidden_size+char_hidden_final]
        concat = tf.concat([final_char_input_seq, repeated_molecules], axis=-1)

        # `upsampled`: [batch_size, char_seq_len, hidden_size]
        upsampled = tf.layers.conv1d(
            inputs=concat,
            filters=self.config.hidden_size,
            kernel_size=self.config.upsampling_kernel_size,
            strides=1,
            padding="same",
            activation=bert_modeling.get_activation(self.config.hidden_act),
            name="conv")
        upsampled = bert_modeling.layer_norm(upsampled)
        if self._is_training:
            upsampled = bert_modeling.dropout(upsampled,
                                              self.config.hidden_dropout_prob)
        return upsampled
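The core of this method is aligning the downsampled molecule encodings with the character positions before the position-wise convolution. Below is a minimal sketch of that shape arithmetic in plain TensorFlow ops; the downsampling rate, dimensions, and the simple repeat are illustrative assumptions, not the repo's actual `_repeat_molecules` implementation.

import tensorflow as tf

batch_size, char_seq_len, char_dim = 2, 12, 8
downsampling_rate = 4          # assumed: one molecule covers 4 characters
molecule_seq_len, molecule_dim = char_seq_len // downsampling_rate, 16

chars = tf.random.uniform([batch_size, char_seq_len, char_dim])
molecules = tf.random.uniform([batch_size, molecule_seq_len, molecule_dim])

# Repeat each molecule so it lines up with the characters it covers
# (the role `_repeat_molecules` plays above).
repeated_molecules = tf.repeat(molecules, repeats=downsampling_rate, axis=1)
assert repeated_molecules.shape == (batch_size, char_seq_len, molecule_dim)

# Concatenate along the feature axis; the position-wise conv1d above then
# projects this back down to hidden_size.
concat = tf.concat([chars, repeated_molecules], axis=-1)
assert concat.shape == (batch_size, char_seq_len, char_dim + molecule_dim)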
Code example #2
File: modeling.py Project: greck2908/language
    def _encode_final_chars(
            self, final_char_input_seq: tf.Tensor,
            char_attention_mask: tf.Tensor, full_molecules: tf.Tensor,
            final_seq_char_positions: Optional[tf.Tensor]) -> tf.Tensor:
        """Run a shallow/low-dim transformer to get a final character encoding."""

        # `final_char_input_seq` is a projected version of the deep molecule BERT
        # stack with slice-wise resnet connections.
        with tf.variable_scope("final_char_encoder"):
            # `upsampled`: [batch_size, char_seq_len, hidden_size]
            upsampled = self._upsample_molecules_to_chars(
                final_char_input_seq, full_molecules)

            if final_seq_char_positions is not None:
                # Limit transformer query seq and attention mask to these character
                # positions to greatly reduce the compute cost. Typically, this is just
                # done for the MLM training task.

                # `final_seq_char_query`: [batch, final_seq_char_len, char_dim]
                final_seq_char_query = tf.gather(upsampled,
                                                 final_seq_char_positions,
                                                 batch_dims=1)
                char_attention_mask = tf.gather(char_attention_mask,
                                                final_seq_char_positions,
                                                batch_dims=1)
            else:
                final_seq_char_query = upsampled

            return bert_modeling.transformer_model(
                input_tensor=final_seq_char_query,
                input_kv_tensor=upsampled,
                attention_mask=char_attention_mask,
                hidden_size=self.config.hidden_size,
                num_hidden_layers=1,
                num_attention_heads=self.config.num_attention_heads,
                intermediate_size=self.config.intermediate_size,
                intermediate_act_fn=bert_modeling.get_activation(
                    self.config.hidden_act),
                hidden_dropout_prob=self.config.hidden_dropout_prob,
                attention_probs_dropout_prob=(
                    self.config.attention_probs_dropout_prob),
                initializer_range=self.config.initializer_range)
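When `final_seq_char_positions` is given, only those character positions are used as transformer queries, while the keys and values still come from the full upsampled sequence. A minimal sketch of the `tf.gather(..., batch_dims=1)` step, with made-up shapes and position indices:

import tensorflow as tf

batch_size, char_seq_len, hidden_size = 2, 10, 4
upsampled = tf.random.uniform([batch_size, char_seq_len, hidden_size])

# e.g. the masked positions scored by the MLM objective, one set per example.
final_seq_char_positions = tf.constant([[1, 4, 7], [0, 3, 9]])

# Per-example gather: [batch_size, 3, hidden_size]. The transformer then uses
# these rows as queries while attending over the full `upsampled` sequence.
final_seq_char_query = tf.gather(upsampled, final_seq_char_positions, batch_dims=1)
print(final_seq_char_query.shape)  # (2, 3, 4)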
Code example #3
    def _encode_initial_chars(self, char_embed_seq, char_attention_mask):
        """Encode characters using shallow/low dim transformer."""
        with tf.variable_scope("initial_char_encoder"):
            return local_attention.local_transformer_model(
                input_tensor=char_embed_seq,
                attention_mask=char_attention_mask,
                hidden_size=self.config.hidden_size,
                num_hidden_layers=1,
                num_attention_heads=self.config.num_attention_heads,
                intermediate_size=self.config.intermediate_size,
                intermediate_act_fn=bert_modeling.get_activation(
                    self.config.hidden_act),
                hidden_dropout_prob=self.config.hidden_dropout_prob,
                attention_probs_dropout_prob=(
                    self.config.attention_probs_dropout_prob),
                initializer_range=self.config.initializer_range,
                always_attend_to_first_position=False,
                first_position_attends_to_all=False,
                attend_from_chunk_width=self.config.local_transformer_stride,
                attend_from_chunk_stride=self.config.local_transformer_stride,
                attend_to_chunk_width=self.config.local_transformer_stride,
                attend_to_chunk_stride=self.config.local_transformer_stride)
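With all chunk widths and strides set to `local_transformer_stride`, attention is confined to local chunks of characters rather than the full sequence. A small sketch of that block-local attention pattern, assuming non-overlapping chunks and illustrative sizes (the real `local_attention.local_transformer_model` also supports overlapping chunks and the first-position options disabled above):

import tensorflow as tf

char_seq_len, chunk = 12, 4                  # assumed values for illustration
block_ids = tf.range(char_seq_len) // chunk  # [0 0 0 0 1 1 1 1 2 2 2 2]

# mask[i, j] == 1 iff character i may attend to character j (same chunk only).
local_mask = tf.cast(
    tf.equal(block_ids[:, None], block_ids[None, :]), tf.int32)

# Each of the 3 chunks contributes a dense 4x4 block; everything else is 0.
assert int(tf.reduce_sum(local_mask)) == (char_seq_len // chunk) * chunk * chunk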