Example #1
# Requires TensorFlow 1.x and the tensor2tensor library; the import paths below
# assume the tensor2tensor.layers module layout.
import tensorflow as tf

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers


def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
    """Complete attention layer with preprocessing."""
    # Use the same separability for both conv layers in the block; when the
    # configured separability is negative, shift the first layer by one step.
    separabilities = [hparams.separability, hparams.separability]
    if hparams.separability < 0:
        separabilities = [hparams.separability - 1, hparams.separability]
    # Add a positional timing signal to the shifted targets, then preprocess
    # them with two sub-separable conv layers (kernel 5 along the time axis,
    # the second dilated by 4), left-padded so no position sees future targets.
    targets_timed = common_layers.subseparable_conv_block(
        common_layers.add_timing_signal(targets_shifted),
        hparams.hidden_size, [((1, 1), (5, 1)), ((4, 1), (5, 1))],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separabilities=separabilities,
        name="targets_time")
    if hparams.attention_type == "transformer":
        # Drop the dummy width dimension: [batch, length, 1, hidden] becomes
        # [batch, length, hidden] for the attention calls below.
        targets_timed = tf.squeeze(targets_timed, 2)
        target_shape = tf.shape(targets_timed)
        targets_segment = tf.zeros([target_shape[0], target_shape[1]])
        # Lower-triangular (causal) bias: each target position may attend only
        # to itself and earlier positions.
        target_attention_bias = common_attention.attention_bias(
            targets_segment, targets_segment, lower_triangular=True)
        # Zero bias for encoder-decoder attention: every target position may
        # attend to every encoded input position.
        inputs_attention_bias = tf.zeros([
            tf.shape(inputs_encoded)[0], hparams.num_heads,
            tf.shape(targets_segment)[1],
            tf.shape(inputs_encoded)[1]
        ])

        # Causal self-attention over the preprocessed targets.
        qv = common_attention.multihead_attention(targets_timed,
                                                  None,
                                                  target_attention_bias,
                                                  hparams.hidden_size,
                                                  hparams.hidden_size,
                                                  hparams.hidden_size,
                                                  hparams.num_heads,
                                                  hparams.attention_dropout,
                                                  name="self_attention")
        # Encoder-decoder attention: the self-attention output attends to the
        # encoded inputs.
        qv = common_attention.multihead_attention(qv,
                                                  inputs_encoded,
                                                  inputs_attention_bias,
                                                  hparams.hidden_size,
                                                  hparams.hidden_size,
                                                  hparams.hidden_size,
                                                  hparams.num_heads,
                                                  hparams.attention_dropout,
                                                  name="encdec_attention")
        # Restore the width-1 dimension so the output matches the 4-D layout
        # of the inputs.
        return tf.expand_dims(qv, 2)
    elif hparams.attention_type == "simple":
        # Single "simple" attention read over the encoded inputs, followed by
        # a residual connection to the shifted targets and normalization.
        targets_with_attention = common_layers.simple_attention(targets_timed,
                                                                inputs_encoded,
                                                                bias=bias)
        return norm_fn(targets_shifted + targets_with_attention,
                       name="attn_norm")