Example #1
    def _run(self,
             inputs,
             sequence_length=None,
             cache=None,
             memory=None,
             memory_sequence_length=None,
             step=None,
             training=None):
        # Process inputs.
        inputs *= self.num_units**0.5
        if self.position_encoder is not None:
            inputs = self.position_encoder(
                inputs, position=step + 1 if step is not None else None)
        inputs = common.dropout(inputs, self.dropout, training=training)

        # Prepare query mask.
        mask = None
        if sequence_length is not None:
            mask = transformer.build_future_mask(
                sequence_length, maximum_length=tf.shape(inputs)[1])

        # Prepare memory mask.
        memory_mask = None
        if memory is not None:
            if not isinstance(memory, (list, tuple)):
                memory = (memory, )
        if memory_sequence_length is not None:
            if not isinstance(memory_sequence_length, (list, tuple)):
                memory_sequence_length = (memory_sequence_length, )
            memory_mask = []
            for mem, mem_length in zip(memory, memory_sequence_length):
                mem_mask = tf.sequence_mask(mem_length,
                                            maxlen=tf.shape(mem)[1],
                                            dtype=tf.float32)
                mem_mask = tf.expand_dims(mem_mask, 1)
                memory_mask.append(mem_mask)

        # Run each layer.
        new_cache = []
        for i, layer in enumerate(self.layers):
            inputs, layer_cache, attention = layer(
                inputs,
                mask=mask,
                memory=memory,
                memory_mask=memory_mask,
                cache=cache[i] if cache is not None else None,
                training=training)
            new_cache.append(layer_cache)
        outputs = self.layer_norm(inputs)
        # Only the attention returned by the last layer is exposed.
        return outputs, new_cache, attention
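Example #1 relies on the OpenNMT-tf helper transformer.build_future_mask to combine the causal mask with the padding mask. As a minimal sketch using only core TensorFlow ops (the function name and shapes below are illustrative assumptions, not the library implementation), an equivalent combined mask can be built like this:

import tensorflow as tf

def future_mask_sketch(sequence_length, maximum_length):
    # Hypothetical stand-in for transformer.build_future_mask.
    # Lower-triangular part: position t may attend to positions <= t.
    future = tf.linalg.band_part(
        tf.ones([maximum_length, maximum_length], dtype=tf.float32), -1, 0)
    # Padding part: [batch, 1, maximum_length].
    padding = tf.sequence_mask(sequence_length, maxlen=maximum_length, dtype=tf.float32)
    padding = tf.expand_dims(padding, 1)
    # Broadcasts to [batch, maximum_length, maximum_length].
    return future * padding

mask = future_mask_sketch(tf.constant([3, 2]), maximum_length=4)  # shape [2, 4, 4]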
Example #2
 def step(self,
          inputs,
          timestep,
          state=None,
          memory=None,
          memory_sequence_length=None,
          training=None):
   outputs, state = self.cell(inputs, state, training=training)
   outputs = common.dropout(outputs, self.dropout, training=training)
   if self.first_layer_attention:
     attention = state[0].alignments
   else:
     attention = state.alignments
   return outputs, state, attention
Example #3
 def call(self, inputs, sequence_length=None, training=None):
   """Encodes :obj:`inputs`."""
   inputs *= self.num_units**0.5
   if self.position_encoder is not None:
     inputs = self.position_encoder(inputs)
   inputs = common.dropout(inputs, self.dropout, training=training)
   mask = None
   if sequence_length is not None:
     mask = tf.sequence_mask(sequence_length, maxlen=tf.shape(inputs)[1], dtype=tf.float32)
     mask = tf.expand_dims(mask, 1)
   for layer in self.layers:
     inputs = layer(inputs, mask=mask, training=training)
   outputs = self.layer_norm(inputs)
   return outputs, None, sequence_length
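The expand_dims(mask, 1) in Example #3 makes the padding mask broadcastable over the query time dimension of the attention scores. A quick shape check with made-up values:

import tensorflow as tf

lengths = tf.constant([3, 2])
mask = tf.sequence_mask(lengths, maxlen=4, dtype=tf.float32)  # [2, 4]
mask = tf.expand_dims(mask, 1)                                # [2, 1, 4]
scores = tf.random.normal([2, 4, 4])                          # [batch, T_q, T_k]
masked = scores * mask + (1.0 - mask) * tf.float32.min        # broadcasts to [2, 4, 4]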
Example #4
  def step(self,
           inputs,
           timestep,
           state=None,
           memory=None,
           memory_sequence_length=None,
           training=None):
    inputs = common.dropout(inputs, rate=self.dropout, training=training)

    new_states = []
    last_outputs, state_0 = self.cells[0](inputs, state[0])
    new_states.append(state_0)

    if memory_sequence_length is not None:
      memory_mask = tf.sequence_mask(memory_sequence_length, maxlen=tf.shape(memory)[1])
    else:
      memory_mask = None

    context, _, attention = self.multi_head_attention(
        tf.expand_dims(last_outputs, 1),
        memory=memory,
        mask=memory_mask,
        training=training)
    attention = attention[:, 0, 0]  # Use the first head for the attention vector.
    context = tf.squeeze(context, axis=1)

    for i in range(1, len(self.cells)):
      inputs = tf.concat([last_outputs, context], axis=-1)
      outputs, state_i = self.cells[i](inputs, state[i], training=training)
      new_states.append(state_i)
      outputs = common.dropout(outputs, rate=self.dropout, training=training)
      if i >= 2:
        outputs += last_outputs
      last_outputs = outputs

    final = tf.concat([last_outputs, context], -1)
    return final, tuple(new_states), attention
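The single-step attention pattern in Example #4 (a one-token query over the encoder memory, then squeezing the context back to [batch, depth]) can be sketched with the stock tf.keras.layers.MultiHeadAttention purely for illustration; the layer and the sizes below are assumptions, not the project's multi_head_attention:

import tensorflow as tf

batch, depth, memory_time = 2, 64, 7
mha = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=depth // 8)
last_outputs = tf.random.normal([batch, depth])          # RNN output at this step
memory = tf.random.normal([batch, memory_time, depth])   # encoder states

context = mha(tf.expand_dims(last_outputs, 1), memory)   # [batch, 1, depth]
context = tf.squeeze(context, axis=1)                    # [batch, depth]
next_inputs = tf.concat([last_outputs, context], axis=-1)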
Example #5
 def _embed(self, inputs, training):
     mask = tf.math.not_equal(inputs, 0)
     outputs = tf.nn.embedding_lookup(self.embedding, inputs)
     outputs = common.dropout(outputs, self.dropout, training=training)
     return outputs, mask
Example #6
 def call(self, features, training=None):
     outputs = tf.nn.embedding_lookup(self.embedding, features["ids"])
     outputs = common.dropout(outputs, self.dropout, training=training)
     return outputs
Example #7
    def call(self, inputs, memory=None, mask=None, cache=None, training=None):  # pylint: disable=arguments-differ
        """Runs the layer.

    Args:
      inputs: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`.
      memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`.
        If ``None``, computes self-attention.
      mask: A ``tf.Tensor`` applied to the dot product.
      cache: A dictionary containing pre-projected keys and values.
      training: Run in training mode.

    Returns:
      A tuple with the attention context, the updated cache and the attention
      probabilities of the first head (if :obj:`return_attention` is ``True``).
    """
        def _compute_kv(x):
            keys = self.linear_keys(x)
            keys = split_heads(keys, self.num_heads)
            values = self.linear_values(x)
            values = split_heads(values, self.num_heads)
            return keys, values

        # Compute queries.
        queries = self.linear_queries(inputs)
        queries = split_heads(queries, self.num_heads)
        queries *= (self.num_units // self.num_heads)**-0.5

        # Compute keys and values.
        if memory is None:
            keys, values = _compute_kv(inputs)
            if cache:
                keys = tf.concat([cache[0], keys], axis=2)
                values = tf.concat([cache[1], values], axis=2)
        else:
            if cache:
                if not self.linear_keys.built:
                    # Ensure that the variable names are not impacted by the tf.cond name
                    # scope if the layers have not already been built.
                    with tf.name_scope(self.linear_keys.name):
                        self.linear_keys.build(memory.shape)
                    with tf.name_scope(self.linear_values.name):
                        self.linear_values.build(memory.shape)
                keys, values = tf.cond(tf.equal(tf.shape(cache[0])[2], 0),
                                       true_fn=lambda: _compute_kv(memory),
                                       false_fn=lambda: cache)
            else:
                keys, values = _compute_kv(memory)

        cache = (keys, values)

        # Dot product attention.
        dot = tf.matmul(queries, keys, transpose_b=True)
        if mask is not None:
            mask = tf.expand_dims(tf.cast(mask, tf.float32),
                                  1)  # Broadcast on heads dimension.
            dot = tf.cast(
                tf.cast(dot, tf.float32) * mask +
                ((1.0 - mask) * tf.float32.min), dot.dtype)
        attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
        drop_attn = common.dropout(attn, self.dropout, training=training)
        heads = tf.matmul(drop_attn, values)

        # Concatenate all heads output.
        combined = combine_heads(heads)
        outputs = self.linear_output(combined)
        if self.return_attention:
            return outputs, cache, attn
        return outputs, cache
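The masking trick used in the dot-product attention above (multiply by the mask, then add (1 - mask) * tf.float32.min) pushes masked positions to the most negative float so softmax assigns them effectively zero probability. A toy check:

import tensorflow as tf

dot = tf.constant([[2.0, 1.0, 0.5, -1.0]])
mask = tf.constant([[1.0, 1.0, 0.0, 0.0]])   # last two positions are padding
masked_dot = dot * mask + (1.0 - mask) * tf.float32.min
attn = tf.nn.softmax(masked_dot)             # approx. [[0.73, 0.27, 0.0, 0.0]]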
Example #8
 def call(self, inputs, training=None):  # pylint: disable=arguments-differ
     """Runs the layer."""
     inner = self.inner(inputs)
     inner = common.dropout(inner, self.dropout, training=training)
     return self.outer(inner)
Example #9
  def call(self, inputs, memory=None, mask=None, cache=None, training=None):  # pylint: disable=arguments-differ
    """Runs the layer.

    Args:
      inputs: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`.
      memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`.
        If ``None``, computes self-attention.
      mask: The dot product mask. A boolean tensor of shape :math:`[B, T_2]` or
        :math:`[B, T_1, T_2]`.
      cache: An optional tuple containing projected keys and values from the
        previous step. Tensors of shape :math:`[B, H, T_2, D / H]`.
      training: Run in training mode.

    Returns:
      A tuple with the attention context, the updated cache and the attention
      probabilities of the first head (if :obj:`return_attention` is ``True``).
    """

    def _compute_kv(x):
      keys = self.linear_keys(x)
      keys = split_heads(keys, self.num_heads)
      values = self.linear_values(x)
      values = split_heads(values, self.num_heads)
      return keys, values

    # Compute queries.
    queries = self.linear_queries(inputs)
    queries = split_heads(queries, self.num_heads)
    queries *= (self.num_units // self.num_heads)**-0.5

    # Compute keys and values.
    if memory is None:
      keys, values = _compute_kv(inputs)
      if cache:
        keys = tf.concat([cache[0], keys], axis=2)
        values = tf.concat([cache[1], values], axis=2)
    else:
      if cache:
        keys, values = tf.cond(
            tf.equal(tf.shape(cache[0])[2], 0),
            true_fn=lambda: _compute_kv(memory),
            false_fn=lambda: cache)
      else:
        keys, values = _compute_kv(memory)

    cache = (keys, values)

    # Dot product attention.
    dot = tf.matmul(queries, keys, transpose_b=True)
    if mask is not None:
      mask = tf.cast(mask, tf.float32)
      if mask.shape.rank == 2:
        mask = tf.expand_dims(mask, 1)  # Broadcast on time dimension.
      mask = tf.expand_dims(mask, 1)  # Broadcast on head dimension.
      dot = tf.cast(tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min), dot.dtype)
    attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
    drop_attn = common.dropout(attn, self.dropout, training=training)
    heads = tf.matmul(drop_attn, values)

    # Concatenate all heads output.
    combined = combine_heads(heads)
    outputs = self.linear_output(combined)
    if self.return_attention:
      return outputs, cache, attn
    return outputs, cache
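During incremental decoding, the cache handling above appends the newly projected key/value slice (one timestep) to the cached tensors along axis 2 of [B, H, T, D/H]. A toy illustration with made-up sizes:

import tensorflow as tf

batch, heads, depth_per_head = 2, 8, 16
cache_keys = tf.zeros([batch, heads, 0, depth_per_head])  # empty cache at step 0
new_keys = tf.random.normal([batch, heads, 1, depth_per_head])

cache_keys = tf.concat([cache_keys, new_keys], axis=2)    # time dimension grows to 1
cache_keys = tf.concat([cache_keys, new_keys], axis=2)    # time dimension grows to 2
print(cache_keys.shape)                                   # (2, 8, 2, 16)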
Example #10
    def _run(
        self,
        inputs,
        sequence_length=None,
        cache=None,
        memory=None,
        memory_sequence_length=None,
        step=None,
        training=None,
    ):
        # Process inputs.
        inputs *= self.num_units ** 0.5
        if self.position_encoder is not None:
            inputs = self.position_encoder(
                inputs, position=step + 1 if step is not None else None
            )
        inputs = common.dropout(inputs, self.dropout, training=training)

        # Prepare query mask.
        mask = None
        if step is None:
            maximum_length = tf.shape(inputs)[1]
            if sequence_length is None:
                batch_size = tf.shape(inputs)[0]
                sequence_length = tf.fill([batch_size], maximum_length)
            mask = transformer.future_mask(
                sequence_length, maximum_length=maximum_length
            )

        # Prepare memory mask.
        memory_mask = None
        if memory is not None:
            if not isinstance(memory, (list, tuple)):
                memory = (memory,)
            if memory_sequence_length is not None:
                if not isinstance(memory_sequence_length, (list, tuple)):
                    memory_sequence_length = (memory_sequence_length,)
                memory_mask = [
                    tf.sequence_mask(mem_length, maxlen=tf.shape(mem)[1])
                    for mem, mem_length in zip(memory, memory_sequence_length)
                ]
            else:
                memory_mask = tuple(None for _ in memory)

        # Run each layer.
        new_cache = []
        attention = []
        for i, layer in enumerate(self.layers):
            inputs, layer_cache, layer_attention = layer(
                inputs,
                mask=mask,
                memory=memory,
                memory_mask=memory_mask,
                cache=cache[i] if cache is not None else None,
                training=training,
            )
            attention.append(layer_attention)
            new_cache.append(layer_cache)
        outputs = self.layer_norm(inputs) if self.layer_norm is not None else inputs

        # Convert list of shape num_layers x num_sources to num_sources x num_layers
        attention = list(map(list, zip(*attention)))
        if attention:
            attention = transformer.MultiHeadAttentionReduction.reduce(
                attention[0],  # Get attention to the first source.
                self.attention_reduction,
            )
        else:
            attention = None

        return outputs, new_cache, attention
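The zip(*attention) bookkeeping above transposes the nested list from num_layers x num_sources to num_sources x num_layers. A plain-Python illustration with placeholder values:

attention = [["l0_s0", "l0_s1"], ["l1_s0", "l1_s1"], ["l2_s0", "l2_s1"]]
attention = list(map(list, zip(*attention)))
# -> [["l0_s0", "l1_s0", "l2_s0"], ["l0_s1", "l1_s1", "l2_s1"]]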
Example #11
 def call(self, inputs, sequence_length=None, training=None):
     inputs = common.dropout(inputs, self.dropout, training=training)
     outputs, state, sequence_length = super().call(
         inputs, sequence_length=sequence_length, training=training)
     projected = self.projection(outputs)
     return (projected, state, sequence_length)
Example #12
 def call(self, inputs, training=None):
     """Runs the layer."""
     inner = self.inner(inputs)
     inner = common.dropout(inner, self.dropout, training=training)
     return self.outer(inner)
Example #13
    def call(self, inputs, memory=None, mask=None, cache=None, training=None):  # pylint: disable=arguments-differ
        """Runs the layer.

        Args:
          inputs: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`.
          memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`.
            If ``None``, computes self-attention.
          mask: The dot product mask. A boolean tensor of shape :math:`[B, T_2]` or
            :math:`[B, T_1, T_2]`.
          cache: An optional tuple containing projected keys and values from the
            previous step. Tensors of shape :math:`[B, H, T_2, D / H]`.
          training: Run in training mode.

        Returns:
          A tuple with the attention context, the updated cache and the attention
          probabilities of the first head (if :obj:`return_attention` is ``True``).
        """
        def _compute_kv(x):
            keys = self.linear_keys(x)
            keys = split_heads(keys, self.num_heads)
            values = self.linear_values(x)
            values = split_heads(values, self.num_heads)
            return keys, values

        # Compute queries.
        queries = self.linear_queries(inputs)
        queries = split_heads(queries, self.num_heads)
        queries *= self.num_units_per_head**-0.5

        # Compute keys and values.
        if memory is None:
            keys, values = _compute_kv(inputs)
            if cache:
                keys = tf.concat([cache[0], keys], axis=2)
                values = tf.concat([cache[1], values], axis=2)
        else:
            if cache:
                keys, values = tf.cond(tf.equal(tf.shape(cache[0])[2], 0),
                                       true_fn=lambda: _compute_kv(memory),
                                       false_fn=lambda: cache)
            else:
                keys, values = _compute_kv(memory)

        if self.maximum_relative_position is not None:
            if memory is not None:
                raise ValueError(
                    "Relative position representations only supports self-attention"
                )
            keys_length = tf.shape(keys)[2]
            relative_pos = relative_positions(keys_length,
                                              self.maximum_relative_position,
                                              with_cache=bool(cache))
            relative_repr_keys = tf.gather(self.relative_position_keys,
                                           relative_pos)
            relative_repr_values = tf.gather(self.relative_position_values,
                                             relative_pos)
        else:
            relative_repr_keys = None
            relative_repr_values = None

        cache = (keys, values)

        # Express the 2D convolution as an expanded matrix multiplication.
        # Does nothing if `num_attended_heads == 1`.
        max_shift = int((self.num_attended_heads - 1) / 2)

        def conv_itself(x):
            rolls = [
                tf.roll(x, shift=shift, axis=1)
                for shift in range(max_shift, -(max_shift + 1), -1)
            ]
            return tf.concat(rolls, axis=2)

        keys = conv_itself(keys)
        values = conv_itself(values)

        # Dot product attention.
        dot = tf.matmul(queries, keys, transpose_b=True)

        if self.attention_span is not None:
            if memory is not None:
                raise ValueError("Attention Span only supports self-attention")
            batch_size = tf.shape(queries)[0]
            maximum_length = tf.shape(queries)[2]
            attention_span = tf.math.minimum(self.attention_span,
                                             maximum_length)
            head_span_mask = tf.linalg.band_part(
                tf.ones([
                    batch_size, self.num_heads, maximum_length, maximum_length
                ]), attention_span, attention_span)
            span_mask = tf.concat([head_span_mask] * self.num_attended_heads,
                                  axis=3)
            span_mask = tf.cast(span_mask, tf.float32)
            dot = tf.cast(
                tf.cast(dot, tf.float32) * span_mask +
                ((1.0 - span_mask) * tf.float32.min), dot.dtype)

        if relative_repr_keys is not None:
            dot += matmul_with_relative_representations(queries,
                                                        relative_repr_keys,
                                                        transpose_b=True)
        if mask is not None:
            mask = tf.cast(mask, tf.float32)
            if mask.shape.rank == 2:
                mask = tf.expand_dims(mask, 1)  # Broadcast on time dimension.
            mask = tf.expand_dims(mask, 1)  # Broadcast on head dimension.
            mask = tf.concat([mask] * self.num_attended_heads,
                             axis=3)  # Replicate on time dimension.
            dot = tf.cast(
                tf.cast(dot, tf.float32) * mask +
                ((1.0 - mask) * tf.float32.min), dot.dtype)
        attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
        drop_attn = common.dropout(attn, self.dropout, training=training)
        heads = tf.matmul(drop_attn, values)
        if relative_repr_values is not None:
            heads += matmul_with_relative_representations(
                drop_attn, relative_repr_values)

        # Concatenate all heads output.
        combined = combine_heads(heads)
        outputs = self.linear_output(combined)
        if self.return_attention:
            return outputs, cache, attn
        return outputs, cache
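The conv_itself helper above concatenates circularly shifted copies of the key/value tensors along the time axis, so each head also sees the keys of its neighbouring heads (tf.roll wraps around at the edges). A toy version with made-up shapes:

import tensorflow as tf

x = tf.reshape(tf.range(6, dtype=tf.float32), [1, 2, 3, 1])  # [B=1, H=2, T=3, D/H=1]
max_shift = 1  # corresponds to num_attended_heads == 3
rolls = [
    tf.roll(x, shift=shift, axis=1)
    for shift in range(max_shift, -(max_shift + 1), -1)
]
expanded = tf.concat(rolls, axis=2)  # [1, 2, 9, 1]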
Example #14
    def call(self, inputs, memory=None, mask=None, cache=None, training=None):
        """Runs the layer.

        Args:
          inputs: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`.
          memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`.
            If ``None``, computes self-attention.
          mask: The dot product mask. A boolean tensor of shape :math:`[B, T_2]` or
            :math:`[B, T_1, T_2]`.
          cache: An optional tuple containing projected keys and values from the
            previous step. Tensors of shape :math:`[B, H, T_2, D / H]`.
          training: Run in training mode.

        Returns:
          A tuple with the attention context, the updated cache and the attention
          weights (if :obj:`return_attention` is ``True``).
        """

        def _compute_kv(x):
            keys = self.linear_keys(x)
            keys = split_heads(keys, self.num_heads)
            values = self.linear_values(x)
            values = split_heads(values, self.num_heads)
            return keys, values

        # Compute queries.
        queries = self.linear_queries(inputs)
        queries = split_heads(queries, self.num_heads)
        queries *= self.num_units_per_head ** -0.5

        # Compute keys and values.
        if memory is None:
            keys, values = _compute_kv(inputs)
            if cache:
                keys = tf.concat([cache[0], keys], axis=2)
                values = tf.concat([cache[1], values], axis=2)
        else:
            if cache:
                keys, values = tf.cond(
                    tf.equal(tf.shape(cache[0])[2], 0),
                    true_fn=lambda: _compute_kv(memory),
                    false_fn=lambda: cache,
                )
            else:
                keys, values = _compute_kv(memory)

        if self.maximum_relative_position is not None:
            if memory is not None:
                raise ValueError(
                    "Relative position representations only supports self-attention"
                )
            keys_length = tf.shape(keys)[2]
            relative_pos = relative_positions(
                keys_length, self.maximum_relative_position, with_cache=bool(cache)
            )

            # Uses gather_nd instead of nn.embedding_lookup for TFLite exporting (TF Issue #42410)
            relative_pos_expanded = tf.expand_dims(relative_pos, axis=-1)
            relative_repr_keys = tf.gather_nd(
                self.relative_position_keys, relative_pos_expanded
            )
            relative_repr_values = tf.gather_nd(
                self.relative_position_values, relative_pos_expanded
            )
        else:
            relative_repr_keys = None
            relative_repr_values = None

        cache = (keys, values)

        # Dot product attention.
        dot = tf.matmul(queries, keys, transpose_b=True)
        if relative_repr_keys is not None:
            dot += matmul_with_relative_representations(
                queries, relative_repr_keys, transpose_b=True
            )
        if mask is not None:
            mask = tf.cast(mask, tf.float32)
            if mask.shape.rank == 2:
                mask = tf.expand_dims(mask, 1)  # Broadcast on time dimension.
            mask = tf.expand_dims(mask, 1)  # Broadcast on head dimension.
            dot = tf.cast(
                tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min),
                dot.dtype,
            )
        attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
        drop_attn = common.dropout(attn, self.dropout, training=training)
        heads = tf.matmul(drop_attn, values)
        if relative_repr_values is not None:
            heads += matmul_with_relative_representations(
                drop_attn, relative_repr_values
            )

        # Concatenate all heads output.
        combined = combine_heads(heads)
        outputs = self.linear_output(combined)
        if self.return_attention:
            return outputs, cache, attn
        return outputs, cache
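The comment in Example #14 notes that tf.gather_nd (with indices expanded on the last axis) replaces tf.gather for TFLite export. Both lookups return the same result on a toy relative-position table; the sizes below are assumptions:

import tensorflow as tf

table = tf.random.normal([5, 8])                # [num_relative_positions, depth]
relative_pos = tf.constant([[0, 1, 2], [1, 2, 3]])

a = tf.gather(table, relative_pos)                              # [2, 3, 8]
b = tf.gather_nd(table, tf.expand_dims(relative_pos, axis=-1))  # [2, 3, 8]
tf.debugging.assert_near(a, b)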