def call(
    self,
    inputs,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    training=False,
):
    """Embed the inputs, optionally project them, and run the encoder.

    Args:
        inputs: One of
            * a tuple/list packed in the order ``(input_ids, attention_mask,
              token_type_ids, position_ids, head_mask, inputs_embeds)``;
              trailing entries may be omitted, in which case the matching
              keyword argument is used;
            * a dict / ``BatchEncoding`` keyed by those same names; keys that
              are absent fall back to the matching keyword argument;
            * anything else, which is used directly as ``input_ids``.
        attention_mask: Optional mask; filled with 1 over the input shape
            when left as ``None``.
        token_type_ids: Optional segment ids; filled with 0 over the input
            shape when left as ``None``.
        position_ids: Optional position ids, forwarded to ``self.embeddings``.
        head_mask: Optional head mask, passed through ``self.get_head_mask``.
        inputs_embeds: Optional pre-computed embeddings, used instead of
            ``input_ids``.
        training: Forwarded to the embedding, projection and encoder calls.

    Returns:
        Whatever ``self.encoder`` returns for the embedded inputs.

    Raises:
        ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
            are provided.
        AssertionError: If a packed ``inputs`` holds more than six entries.
    """
    if isinstance(inputs, (tuple, list)):
        n_packed = len(inputs)
        input_ids = inputs[0]
        # Each later slot overrides its keyword argument only when present.
        if n_packed > 1:
            attention_mask = inputs[1]
        if n_packed > 2:
            token_type_ids = inputs[2]
        if n_packed > 3:
            position_ids = inputs[3]
        if n_packed > 4:
            head_mask = inputs[4]
        if n_packed > 5:
            inputs_embeds = inputs[5]
        # NOTE: this is an `assert`, so the check disappears under `python -O`.
        assert n_packed <= 6, "Too many inputs."
    elif isinstance(inputs, (dict, BatchEncoding)):
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask", attention_mask)
        token_type_ids = inputs.get("token_type_ids", token_type_ids)
        position_ids = inputs.get("position_ids", position_ids)
        head_mask = inputs.get("head_mask", head_mask)
        inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
        # NOTE: this is an `assert`, so the check disappears under `python -O`.
        assert len(inputs) <= 6, "Too many inputs."
    else:
        input_ids = inputs

    # Exactly one of input_ids / inputs_embeds must be supplied.
    has_ids = input_ids is not None
    has_embeds = inputs_embeds is not None
    if has_ids and has_embeds:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if not (has_ids or has_embeds):
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if has_ids:
        input_shape = shape_list(input_ids)
    else:
        # Drop the trailing (feature) axis of the embeddings.
        input_shape = shape_list(inputs_embeds)[:-1]

    if attention_mask is None:
        attention_mask = tf.fill(input_shape, 1)
    if token_type_ids is None:
        token_type_ids = tf.fill(input_shape, 0)

    extended_mask = self.get_extended_attention_mask(attention_mask, input_shape)
    resolved_head_mask = self.get_head_mask(head_mask)

    embedded = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)

    # The projection layer is only applied when the instance defines one.
    if hasattr(self, "embeddings_project"):
        embedded = self.embeddings_project(embedded, training=training)

    return self.encoder([embedded, extended_mask, resolved_head_mask], training=training)
def _linear(self, inputs):
    """Computes logits by multiplying inputs with the transposed word embeddings.

    Args:
        inputs: A float tensor with shape [batch_size, length, embedding_size]
            (the last axis must match ``self.embedding_size``, which is what
            the inputs are reshaped to before the matmul).

    Returns:
        A tensor with shape [batch_size, length, vocab_size].
    """
    # Query the shape once rather than once per dimension.
    dims = shape_list(inputs)
    batch_size = dims[0]
    length = dims[1]

    # Collapse batch and length so the projection is a single 2-D matmul;
    # transpose_b=True multiplies by the transpose of self.word_embeddings.
    flat_inputs = tf.reshape(inputs, [-1, self.embedding_size])
    flat_logits = tf.matmul(flat_inputs, self.word_embeddings, transpose_b=True)

    # Restore the leading [batch_size, length] axes.
    return tf.reshape(flat_logits, [batch_size, length, self.vocab_size])
def _embedding(self, inputs, training=False):
    """Builds the summed, normalised embeddings for a batch.

    Args:
        inputs: A 4-element sequence unpacked as ``(input_ids, position_ids,
            token_type_ids, inputs_embeds)``. ``position_ids`` and
            ``token_type_ids`` may be ``None`` and are then generated here;
            ``inputs_embeds`` may be ``None``, in which case the word vectors
            are gathered from ``self.word_embeddings`` using ``input_ids``.
        training: Forwarded to the dropout layer.

    Returns:
        The output of ``self.dropout`` applied to the layer-normalised sum of
        the word, position and token-type embeddings.
    """
    input_ids, position_ids, token_type_ids, inputs_embeds = inputs

    # Without ids, take the shape from the embeddings minus their last axis.
    input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)[:-1]
    seq_length = input_shape[1]

    if position_ids is None:
        # Default positions 0..seq_length-1 with a leading axis of size 1.
        position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
    if token_type_ids is None:
        token_type_ids = tf.fill(input_shape, 0)
    if inputs_embeds is None:
        inputs_embeds = tf.gather(self.word_embeddings, input_ids)

    position_vectors = self.position_embeddings(position_ids)
    token_type_vectors = self.token_type_embeddings(token_type_ids)

    summed = inputs_embeds + position_vectors + token_type_vectors
    normalised = self.LayerNorm(summed)
    return self.dropout(normalised, training=training)