def _upsample_molecules_to_chars(self, final_char_input_seq: tf.Tensor,
                                 full_molecules: tf.Tensor) -> tf.Tensor:
  """Upsample molecule representations back to character resolution.

  Each molecule is repeated across the character positions it covers, the
  result is concatenated with the per-character sequence feature-wise, and a
  1-D convolution mixes the two back down to `hidden_size`.

  (Note: the previous docstring was copy-pasted from `_encode_final_chars`
  and described the wrong operation.)

  Args:
    final_char_input_seq: <float>[batch, char_seq_len, char_dim] per-character
      representations (a projected version of the deep molecule stack).
    full_molecules: <float>[batch, molecule_seq_len, molecule_dim] deep
      transformer outputs at molecule resolution.

  Returns:
    <float>[batch, char_seq_len, hidden_size] upsampled character
    representations (layer-normalized; dropout applied when training).
  """
  _, char_seq_length, _ = bert_modeling.get_shape_list(final_char_input_seq)

  # `repeated_molecules`: [batch_size, char_seq_len, molecule_hidden_size]
  repeated_molecules = self._repeat_molecules(
      full_molecules, char_seq_length=char_seq_length)

  # `concat`:
  # [batch_size, char_seq_len, molecule_hidden_size+char_hidden_final]
  concat = tf.concat([final_char_input_seq, repeated_molecules], axis=-1)

  # `upsampled`: [batch_size, char_seq_len, hidden_size]
  upsampled = tf.layers.conv1d(
      inputs=concat,
      filters=self.config.hidden_size,
      kernel_size=self.config.upsampling_kernel_size,
      strides=1,
      padding="same",
      activation=bert_modeling.get_activation(self.config.hidden_act),
      name="conv")
  upsampled = bert_modeling.layer_norm(upsampled)
  # Dropout only at training time; inference uses the deterministic path.
  if self._is_training:
    upsampled = bert_modeling.dropout(upsampled,
                                      self.config.hidden_dropout_prob)
  return upsampled
def _encode_final_chars(
    self, final_char_input_seq: tf.Tensor,
    char_attention_mask: tf.Tensor, full_molecules: tf.Tensor,
    final_seq_char_positions: Optional[tf.Tensor]) -> tf.Tensor:
  """Run a shallow/low-dim transformer to get a final character encoding."""
  # `final_char_input_seq` is a projected version of the deep molecule BERT
  # stack with slice-wise resnet connections.
  with tf.variable_scope("final_char_encoder"):
    # `upsampled`: [batch_size, char_seq_len, hidden_size]
    upsampled = self._upsample_molecules_to_chars(final_char_input_seq,
                                                  full_molecules)

    # By default every character position is a transformer query; when
    # `final_seq_char_positions` is given, restrict the query sequence (and
    # the attention mask) to just those positions to greatly reduce compute.
    # Typically this is done only for the MLM training task.
    query_seq = upsampled
    if final_seq_char_positions is not None:
      # `query_seq`: [batch, final_seq_char_len, char_dim]
      query_seq = tf.gather(
          upsampled, final_seq_char_positions, batch_dims=1)
      char_attention_mask = tf.gather(
          char_attention_mask, final_seq_char_positions, batch_dims=1)

    cfg = self.config
    return bert_modeling.transformer_model(
        input_tensor=query_seq,
        input_kv_tensor=upsampled,
        attention_mask=char_attention_mask,
        hidden_size=cfg.hidden_size,
        num_hidden_layers=1,
        num_attention_heads=cfg.num_attention_heads,
        intermediate_size=cfg.intermediate_size,
        intermediate_act_fn=bert_modeling.get_activation(cfg.hidden_act),
        hidden_dropout_prob=cfg.hidden_dropout_prob,
        attention_probs_dropout_prob=cfg.attention_probs_dropout_prob,
        initializer_range=cfg.initializer_range)
def _encode_initial_chars(self, char_embed_seq, char_attention_mask):
  """Encode characters using shallow/low dim transformer."""
  cfg = self.config
  # All four chunking parameters share the same stride, so chunks tile the
  # sequence without overlap on either the query or the key/value side.
  stride = cfg.local_transformer_stride
  with tf.variable_scope("initial_char_encoder"):
    return local_attention.local_transformer_model(
        input_tensor=char_embed_seq,
        attention_mask=char_attention_mask,
        hidden_size=cfg.hidden_size,
        num_hidden_layers=1,
        num_attention_heads=cfg.num_attention_heads,
        intermediate_size=cfg.intermediate_size,
        intermediate_act_fn=bert_modeling.get_activation(cfg.hidden_act),
        hidden_dropout_prob=cfg.hidden_dropout_prob,
        attention_probs_dropout_prob=cfg.attention_probs_dropout_prob,
        initializer_range=cfg.initializer_range,
        always_attend_to_first_position=False,
        first_position_attends_to_all=False,
        attend_from_chunk_width=stride,
        attend_from_chunk_stride=stride,
        attend_to_chunk_width=stride,
        attend_to_chunk_stride=stride)