def create_attention_mask_from_input_mask(from_tensor, to_mask):
  """Create 3D attention mask from a 2D tensor mask.

  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  """
  from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
  batch_size = from_shape[0]
  from_seq_length = from_shape[1]

  to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
  to_seq_length = to_shape[1]

  to_mask = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
      dtype=from_tensor.dtype)

  # We don't assume that `from_tensor` is a mask (although it could be). We
  # don't actually care if we attend *from* padding tokens (only *to* padding
  # tokens), so we create a tensor of all ones.
  #
  # `broadcast_ones` = [batch_size, from_seq_length, 1]
  broadcast_ones = tf.ones(
      shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype)

  # Here we broadcast along two dimensions to create the mask.
  mask = broadcast_ones * to_mask

  return mask
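# A minimal usage sketch (illustrative only; the tensors below are made up):
# a batch of 2 sequences of length 4, where the last token of the second
# sequence is padding. Every "from" position may attend to every non-padding
# "to" position.
#
#   from_tensor = tf.zeros([2, 4, 8])            # [batch, from_seq, width]
#   to_mask = tf.constant([[1, 1, 1, 1],
#                          [1, 1, 1, 0]])        # [batch, to_seq]
#   attention_mask = create_attention_mask_from_input_mask(from_tensor, to_mask)
#   # attention_mask has shape [2, 4, 4] and attention_mask[1, :, 3] == 0.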
def call(self, inputs, **kwargs):
  """Implements call() for the layer."""
  unpacked_inputs = tf_utils.unpack_inputs(inputs)
  word_embeddings = unpacked_inputs[0]
  segment_ids = unpacked_inputs[1]
  column_ids = unpacked_inputs[2]
  row_ids = unpacked_inputs[3]
  prev_label_ids = unpacked_inputs[4]
  column_ranks = unpacked_inputs[5]
  inv_column_ranks = unpacked_inputs[6]
  numeric_relations = unpacked_inputs[7]

  input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  width = input_shape[2]

  output = word_embeddings
  token_type_ids_list = [
      segment_ids, column_ids, row_ids, prev_label_ids, column_ranks,
      inv_column_ranks, numeric_relations
  ]
  token_type_embeddings_list = [
      self.segment_embeddings, self.column_embeddings, self.row_embeddings,
      self.prev_label_embeddings, self.column_ranks_embeddings,
      self.inv_column_ranks_embeddings, self.numeric_relations_embeddings
  ]
  if self.use_type_embeddings:
    for i, (token_type_ids, type_embeddings) in enumerate(
        zip(token_type_ids_list, token_type_embeddings_list)):
      # Look up each token-type embedding via one-hot matmul over its own
      # vocabulary, then add it to the running sum of embeddings.
      flat_token_type_ids = tf.reshape(token_type_ids, [-1])
      one_hot_ids = tf.one_hot(
          flat_token_type_ids,
          depth=self.token_type_vocab_size[i],
          dtype=self.dtype)
      token_type_embeddings = tf.matmul(one_hot_ids, type_embeddings)
      token_type_embeddings = tf.reshape(token_type_embeddings,
                                         [batch_size, seq_length, width])
      output += token_type_embeddings

  if self.use_position_embeddings:
    if not self.reset_position_index_per_cell:
      # Absolute positions: take the first `seq_length` rows of the position
      # embedding table and broadcast them over the batch.
      position_embeddings = tf.expand_dims(
          tf.slice(self.position_embeddings, [0, 0], [seq_length, width]),
          axis=0)
    else:
      # Per-cell positions: group tokens by their (column, row) cell and
      # subtract the first position of each cell, so position indices restart
      # at 0 within every table cell.
      col_index = segmented_tensor.IndexMap(
          token_type_ids_list[1], self.token_type_vocab_size[1], batch_dims=1)
      row_index = segmented_tensor.IndexMap(
          token_type_ids_list[2], self.token_type_vocab_size[2], batch_dims=1)
      full_index = segmented_tensor.ProductIndexMap(col_index, row_index)
      position = tf.expand_dims(tf.range(seq_length), axis=0)
      batched_position = tf.repeat(position, repeats=batch_size, axis=0)
      first_position_per_segment = segmented_tensor.reduce_min(
          batched_position, full_index)[0]
      first_position = segmented_tensor.gather(first_position_per_segment,
                                               full_index)
      position_embeddings = tf.nn.embedding_lookup(self.position_embeddings,
                                                   position - first_position)
    output += position_embeddings

  output = self.output_layer_norm(output)
  output = self.output_dropout(
      output, training=kwargs.get('training', False))
  return output
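# Worked example of the per-cell position reset above (illustrative values):
# for one example with column_ids = [0, 1, 1, 2, 2, 2] and identical row_ids,
# the (column, row) product index yields three cells {0}, {1, 2}, {3, 4, 5}, so
#   position                   = [0, 1, 2, 3, 4, 5]
#   first_position_per_segment = [0, 1, 3]
#   first_position             = [0, 1, 1, 3, 3, 3]
#   position - first_position  = [0, 0, 1, 0, 1, 2]
# i.e. position indices restart at 0 inside each table cell.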
def call(self, inputs):
  """Implements call() for the layer."""
  input_shape = tf_utils.get_shape_list(inputs)
  # Gather an embedding vector for each id; the reshape makes the static
  # shape explicit when some input dimensions are dynamic.
  output = tf.nn.embedding_lookup(self.embeddings, inputs)
  output = tf.reshape(output, input_shape + [self.embedding_size])
  return output
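# Usage sketch (illustrative; `word_embedding_layer` is a hypothetical instance
# of this layer built with embedding_size=8):
#   ids = tf.constant([[5, 17, 0], [2, 2, 9]])   # [batch_size, seq_length]
#   embedded = word_embedding_layer(ids)         # shape [2, 3, 8]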