    def call(self, inputs):
        from_tensor = inputs[0]
        to_mask = inputs[1]
        from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]

        to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
        to_seq_length = to_shape[1]

        to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
                          dtype=from_tensor.dtype)

        # We don't assume that `from_tensor` is a mask (although it could be). We
        # don't actually care if we attend *from* padding tokens (only *to* padding
        # tokens), so we create a tensor of all ones.
        #
        # `broadcast_ones` = [batch_size, from_seq_length, 1]
        broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1],
                                 dtype=from_tensor.dtype)

        # Here we broadcast along two dimensions to create the mask.
        mask = broadcast_ones * to_mask

        return mask
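A toy walk-through of the broadcast above (not from the original source), assuming batch_size=1, from_seq_length=3, to_seq_length=4 with the last `to` position padded:

import tensorflow as tf

to_mask = tf.constant([[1, 1, 1, 0]], dtype=tf.float32)   # [B, T]
broadcast_ones = tf.ones([1, 3, 1], dtype=tf.float32)     # [B, F, 1]
mask = broadcast_ones * tf.reshape(to_mask, [1, 1, 4])    # [B, F, T]
# Every `from` row is [1., 1., 1., 0.]: attention *to* the padded position is
# masked regardless of which position we attend *from*.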
Example #2
    def call(self, inputs, past_length=0):
        """Implements call() for the layer."""
        if self._use_dynamic_slicing:
            input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
            seq_length = input_shape[1]
            width = input_shape[2]

            # If input = (3 x 5 x 7) (batch x sequence x width) and past_length = 0,
            # the output is equivalent to
            # tf.gather(self._position_embeddings, [0, 1, 2, 3, 4]).

            # If input = (3 x 5 x 7) (batch x sequence x width) and past_length = 2,
            # the output is equivalent to
            # tf.gather(self._position_embeddings, [2, 3, 4]).

            # Before the expand_dims below,
            # tf.shape(position_embeddings)[0] == seq_length - past_length.

            position_embeddings = tf.expand_dims(
                tf.slice(self._position_embeddings, [0, 0],
                         [seq_length, width])[past_length:],
                axis=0,
            )
        else:
            position_embeddings = tf.expand_dims(self._position_embeddings,
                                                 axis=0)

        return position_embeddings
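A small sketch of what the dynamic slice produces (illustrative only), assuming a position table of shape [5, 2] and past_length = 2, matching the comment above:

import tensorflow as tf

position_table = tf.reshape(tf.range(10, dtype=tf.float32), [5, 2])
seq_length, width, past_length = 5, 2, 2
sliced = tf.slice(position_table, [0, 0], [seq_length, width])[past_length:]
# `sliced` holds the rows for positions 2, 3 and 4 -> shape [3, 2].
out = tf.expand_dims(sliced, axis=0)  # [1, 3, 2], broadcastable over the batch.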
Example #3
    def call(self, inputs):
        from_tensor = inputs
        from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]

        # 2D Lower Triangular Mask
        from_mask = attention_mask_square(from_seq_length)

        # Replicate the 2D mask across the batch dimension (batch_size copies)
        mask = tf.ones([batch_size, 1, 1]) * from_mask

        return mask
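`attention_mask_square` is not defined in this snippet; a minimal sketch of such a helper, assuming it returns a [length, length] lower-triangular (causal) mask, could look like this:

import tensorflow as tf

def attention_mask_square(length, dtype=tf.float32):
    # Keep all sub-diagonals (num_lower=-1) and no super-diagonals (num_upper=0),
    # so position i may only attend to positions <= i.
    return tf.linalg.band_part(tf.ones([length, length], dtype=dtype), -1, 0)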
Example #4
    def call(self, inputs):
        input_shape = tf_utils.get_shape_list(inputs, expected_rank=2)
        input_shape.append(self._embedding_width)
        flat_inputs = tf.reshape(inputs, [-1])
        if self._use_one_hot:
            # Embedding lookup via a one-hot matmul.
            one_hot_data = tf.one_hot(
                flat_inputs, depth=self._vocab_size, dtype=self._dtype)
            embeddings = tf.matmul(one_hot_data, self.embeddings)
        else:
            # CHANGED: the original `tf.gather(self.embeddings, flat_inputs)` is
            # wrapped in tf.identity here.
            embeddings = tf.identity(tf.gather(self.embeddings, flat_inputs))
        embeddings = tf.reshape(embeddings, input_shape)

        return embeddings
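A quick equivalence check (not part of the original layer) showing that the one-hot matmul path and the gather path produce the same lookup, assuming a toy vocabulary of 4 and width 3:

import tensorflow as tf

embeddings = tf.reshape(tf.range(12, dtype=tf.float32), [4, 3])  # [vocab, width]
ids = tf.constant([[0, 2], [3, 1]])                              # [batch, seq]
flat_ids = tf.reshape(ids, [-1])

via_matmul = tf.matmul(tf.one_hot(flat_ids, depth=4, dtype=tf.float32), embeddings)
via_gather = tf.gather(embeddings, flat_ids)
tf.debugging.assert_near(via_matmul, via_gather)

out = tf.reshape(via_gather, [2, 2, 3])  # back to [batch, seq, width]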
def merge_attention_heads(x):
    """Reshapes [batch, n_heads, sequence, feature] back to [batch, sequence, n_heads * feature]."""
    batch, n_heads, sequence, feature_length = tf_utils.get_shape_list(x)
    return tf.reshape(tf.transpose(x, [0, 2, 1, 3]),
                      [batch, sequence, n_heads * feature_length])
Example #6
    def split_states(self, x, n_heads):
        """Reshapes the last dimension of x into [n_heads, width // n_heads]."""
        batch, sequence, width = tf_utils.get_shape_list(x)
        return tf.reshape(x, [batch, sequence, n_heads, width // n_heads])
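A round-trip sketch (illustrative, not from the source) showing that `split_states` followed by a heads-first transpose is undone by `merge_attention_heads`, assuming batch=2, sequence=4, width=8 and n_heads=2:

import tensorflow as tf

x = tf.random.normal([2, 4, 8])
split = tf.reshape(x, [2, 4, 2, 4])              # split_states(x, n_heads=2)
heads_first = tf.transpose(split, [0, 2, 1, 3])  # [batch, heads, seq, head_width]
merged = tf.reshape(tf.transpose(heads_first, [0, 2, 1, 3]),
                    [2, 4, 8])                   # merge_attention_heads(heads_first)
tf.debugging.assert_near(x, merged)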
    def call(self, inputs, cache_key=None, cache_value=None, seed=None):
        """Implements a multi-headed attention layer from from_tensor to to_tensor.

        Args:
          from_tensor: float Tensor of shape [batch_size, from_seq_length,
            from_width]
          to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
          attention_mask: (optional) int32 Tensor of shape [batch_size,
            from_seq_length, to_seq_length]. The values should be 1 or 0. The
            attention scores will effectively be set to -infinity for any positions
            in the mask that are 0, and will be unchanged for positions that are 1.
          band_mask: (optional) int32 Tensor of shape [batch_size, 1,
            from_seq_length//from_block_size-4, from_block_size, 3*to_block_size].
            The values should be 1 or 0. The attention scores will effectively be
            set to -infinity for any positions in the mask that are 0, and will be
            unchanged for positions that are 1.
          from_mask: (optional) int32 Tensor of shape [batch_size, 1,
            from_seq_length, 1]. The values should be 1 or 0. The
            attention scores will effectively be set to -infinity for any positions
            in the mask that are 0, and will be unchanged for positions that are 1.
          to_mask: (optional) int32 Tensor of shape [batch_size, 1, 1,
            to_seq_length]. The values should be 1 or 0. The
            attention scores will effectively be set to -infinity for any positions
            in the mask that are 0, and will be unchanged for positions that are 1.
          from_blocked_mask: (optional) int32 Tensor of shape [batch_size,
            from_seq_length//from_block_size, from_block_size].
            Same as from_mask, just reshaped.
          to_blocked_mask: (optional) int32 Tensor of shape [batch_size,
            to_seq_length//to_block_size, to_block_size].
            Same as to_mask, just reshaped.
          cache: (Used during prediction) A dictionary with tensors containing
            results of previous attentions. The dictionary must have the items:
                {"k": tensor with shape
                      [batch_size, max_len, num_attention_heads, size_per_head],
                 "v": tensor with shape
                      [batch_size, max_len, num_attention_heads, size_per_head]}
          decode_i: (Used during prediction) current location of decoding
          training: Boolean indicating whether the call is training or inference.

        Returns:
          float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
            size_per_head].

        Raises:
          ValueError: Any of the arguments or tensor shapes are invalid.
          NotImplementedError: For unknown attention type.
        """

        from_tensor = inputs[0]
        to_tensor = inputs[1]
        input_mask = inputs[2]

        # Scalar dimensions referenced here:
        #   B = batch size (number of sequences)
        #   F = `from_tensor` sequence length
        #   T = `to_tensor` sequence length
        #   N = `num_attention_heads`
        #   H = `size_per_head`
        # `query_tensor` = [B, F, N, H]
        query_tensor = self._query_dense(from_tensor)

        # `key_tensor` = [B, T, N, H]
        key_tensor = self._key_dense(to_tensor)

        # `value_tensor` = [B, T, N, H]
        value_tensor = self._value_dense(to_tensor)

        # Transpose to [B, N, F, H] for the query and [B, N, T, H] for key/value.
        query_tensor = tf.transpose(query_tensor, [0, 2, 1, 3])
        key_tensor = tf.transpose(key_tensor, [0, 2, 1, 3])
        value_tensor = tf.transpose(value_tensor, [0, 2, 1, 3])

        # Prepare the masks needed by the block-sparse attention.

        batch_size, encoder_length, hidden_size = tf_utils.get_shape_list(
            from_tensor, expected_rank=3)
        from_seq_length = to_seq_length = encoder_length

        # `blocked_encoder_mask` = [B, encoder_length // block_size, block_size]
        encoder_block_size = self.from_block_size
        blocked_encoder_mask = tf.reshape(
            input_mask, (batch_size, encoder_length // encoder_block_size,
                         encoder_block_size))

        # `encoder_from_mask` = [B, 1, encoder_length, 1]
        encoder_from_mask = tf.reshape(input_mask,
                                       (batch_size, 1, encoder_length, 1))
        # `encoder_to_mask` = [B, 1, 1, encoder_length]
        encoder_to_mask = tf.reshape(input_mask,
                                     (batch_size, 1, 1, encoder_length))

        # Band mask for the sliding-window (band) attention:
        # [B, 1, encoder_length // block_size - 4, block_size, 3 * block_size]
        band_mask = create_band_mask_from_inputs(blocked_encoder_mask,
                                                 blocked_encoder_mask)

        context_layer = bigbird_block_sparse_attention(
            query_tensor,
            key_tensor,
            value_tensor,
            band_mask,
            encoder_from_mask,
            encoder_to_mask,
            blocked_encoder_mask,
            blocked_encoder_mask,
            self._num_heads,
            self.num_rand_blocks,
            self._head_size,
            batch_size,
            from_seq_length,
            to_seq_length,
            self.from_block_size,
            self.to_block_size,
            seed,
        )
        return tf.reshape(
            context_layer,
            (batch_size, from_seq_length, hidden_size)), cache_key, cache_value
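`create_band_mask_from_inputs` is not shown in this snippet; a sketch consistent with the band-mask shape documented above ([batch, 1, num_blocks - 4, block_size, 3 * block_size]), assuming each middle block attends to its left, center and right neighbouring blocks:

import tensorflow as tf

def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask):
    # Both masks: [batch, num_blocks, block_size].
    exp_blocked_to_pad = tf.concat(
        [to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2],
         to_blocked_mask[:, 3:-1]], axis=2)          # [batch, blocks - 4, 3 * block]
    band_mask = tf.einsum("BLQ,BLK->BLQK",
                          from_blocked_mask[:, 2:-2], exp_blocked_to_pad)
    return tf.expand_dims(band_mask, 1)              # [batch, 1, blocks - 4, block, 3 * block]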