def add_timing_signal_1d_given_position(x,
                                        position,
                                        min_timescale=1.0,
                                        max_timescale=1.0e4):
  """Adds sinusoids of diff frequencies to a Tensor, with timing position given.

  The signal's first half of channels holds sines and the second half cosines
  of `position` multiplied by a geometric sequence of `channels // 2` inverse
  timescales. When `channels` is odd, one zero channel is appended so the
  signal matches the shape of x.

  Args:
    x: a Tensor with shape [batch, length, channels]
    position: a Tensor with shape [batch, length]
    min_timescale: a float
    max_timescale: a float

  Returns:
    a Tensor the same shape as x: x plus the timing signal, where the signal
    is cast like x.
  """
  channels = common_layers.shape_list(x)[2]
  num_timescales = channels // 2
  # Clamp the denominator to 1: with a single timescale (channels in {2, 3})
  # the unclamped value is 0, which makes the increment inf and the signal NaN.
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      tf.maximum(tf.to_float(num_timescales) - 1, 1))
  inv_timescales = min_timescale * tf.exp(
      tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
  # [batch, length, 1] * [1, 1, num_timescales] -> [batch, length, num_timescales]
  scaled_time = (
      tf.expand_dims(tf.to_float(position), 2) * tf.expand_dims(
          tf.expand_dims(inv_timescales, 0), 0))
  signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2)
  # Append one zero channel when `channels` is odd.
  signal = tf.pad(signal, [[0, 0], [0, 0], [0, tf.mod(channels, 2)]])
  signal = common_layers.cast_like(signal, x)
  # The signal is added to x (as the docstring states); callers such as
  # transformer_prepare_encoder replace their input with this return value.
  return x + signal
def prepare_decoder(targets, hparams):
  """Prepare decoder for images.

  Args:
    targets: the target image Tensor. During training it is
      [batch, IMG_LEN, IMG_LEN, 3]; at inference it is
      [batch, curr_infer_length, 1, 1] and is padded and reshaped back into
      rows of hparams.img_len pixels here.
    hparams: run hyperparameters. Reads num_channels, mode, block_raster_scan,
      img_len, query_shape, dec_attention_type and hidden_size.

  Returns:
    A tuple (x, rows, cols): the right-shifted decoder input with position
    signals added (cast like `targets`), and dimensions 1 and 2 of its shape
    after prepare_image.

  Raises:
    ValueError: at inference with block_raster_scan, if
      img_len * num_channels is not divisible by query_shape[1], or img_len is
      not divisible by query_shape[0].
  """
  targets_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels

  # During training, images are [batch, IMG_LEN, IMG_LEN, 3].
  # At inference, they are [batch, curr_infer_length, 1, 1].
  # tf.estimator.ModeKeys.PREDICT replaces the deprecated
  # tf.contrib.learn.ModeKeys.INFER; both hold the string "infer".
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    curr_infer_length = targets_shape[1]
    if hparams.block_raster_scan:
      total_block_width = hparams.img_len * channels
      # Explicit checks instead of `assert`, which is stripped under -O and
      # would otherwise surface later as an opaque reshape failure.
      if total_block_width % hparams.query_shape[1] != 0:
        raise ValueError(
            "img_len * num_channels (%d) must be divisible by query_shape[1] "
            "(%d) for block raster scan." %
            (total_block_width, hparams.query_shape[1]))
      if hparams.img_len % hparams.query_shape[0] != 0:
        raise ValueError(
            "img_len (%d) must be divisible by query_shape[0] (%d) for block "
            "raster scan." % (hparams.img_len, hparams.query_shape[0]))
      # Decoding is in block raster scan order. We divide the image into
      # hparams.query_shape blocks and then decode each block in raster scan.
      # To make that compatible with our inference pipeline, pad the target so
      # that rows is a multiple of query_shape and columns is a multiple of
      # hparams.img_len*channels
      block_padding_factor = total_block_width * hparams.query_shape[0]
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % block_padding_factor],
          [0, 0], [0, 0]])

      num_blocks = total_block_width // hparams.query_shape[1]
      # Reshape the image to represent blocks
      target_blocks = tf.reshape(
          targets, [targets_shape[0], -1, num_blocks, hparams.query_shape[0],
                    hparams.query_shape[1]])
      # Swap the block axis with the in-block row axis to read the image 2D.
      targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
    else:
      # Pad so the length of targets is a multiple of img_len times the
      # number of channels. This is needed for positional encodings and for
      # doing the RGB lookup.
      padding_factor = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]])
    targets = tf.reshape(targets,
                         [targets_shape[0], -1, hparams.img_len, channels])
  # Preprocess image
  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  if (hparams.dec_attention_type == AttentionType.LOCAL_2D or
      hparams.dec_attention_type == AttentionType.LOCAL_BLOCK):
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Flatten rows*cols into one sequence axis, shift right by one position,
    # restore the 2D layout, then add position signals.
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1]*x_shape[2], hparams.hidden_size])
    x = common_layers.shift_right_3d(x)
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1], x_shape[2], hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  x = common_layers.cast_like(x, targets)
  return x, x_shape[1], x_shape[2]
def prepare_decoder(targets, hparams):
  """Builds the shifted, position-encoded decoder input for image models.

  Training inputs arrive as [batch, IMG_LEN, IMG_LEN, 3]. At inference they
  arrive as [batch, curr_infer_length, 1, 1] and are first padded and reshaped
  back into rows of hparams.img_len pixels, optionally in block raster-scan
  order.

  Args:
    targets: the target image Tensor (see shapes above).
    hparams: run hyperparameters.

  Returns:
    (x, rows, cols): the decoder input cast like `targets`, plus dimensions 1
    and 2 of its shape after prepare_image.
  """
  tgt_shape = common_layers.shape_list(targets)
  batch = tgt_shape[0]
  n_channels = hparams.num_channels

  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    infer_len = tgt_shape[1]
    if hparams.block_raster_scan:
      assert hparams.img_len*n_channels % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      row_width = hparams.img_len * n_channels
      # Decode in block raster-scan order: pad so the sequence fills whole
      # bands of query_shape[0] rows, then regroup the pixels into blocks.
      pad_to = row_width * hparams.query_shape[0]
      targets = tf.pad(
          targets, [[0, 0], [0, -infer_len % pad_to], [0, 0], [0, 0]])
      blocks_per_row = row_width // hparams.query_shape[1]
      blocked = tf.reshape(targets, [batch, -1, blocks_per_row,
                                     hparams.query_shape[0],
                                     hparams.query_shape[1]])
      # Swap the block axis with the in-block row axis to read the image 2D.
      targets = tf.transpose(blocked, [0, 1, 3, 2, 4])
    else:
      # Pad to a multiple of channels * img_len; the positional encodings and
      # the RGB lookup both need whole rows.
      pad_to = n_channels * hparams.img_len
      targets = tf.pad(
          targets, [[0, 0], [0, -infer_len % pad_to], [0, 0], [0, 0]])
    targets = tf.reshape(targets, [batch, -1, hparams.img_len, n_channels])

  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  rows, cols = x_shape[1], x_shape[2]
  local_attention = hparams.dec_attention_type in (AttentionType.LOCAL_2D,
                                                    AttentionType.LOCAL_BLOCK)
  if local_attention:
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
  else:
    # Shift right over the flattened rows*cols sequence, then restore 2D.
    flat = tf.reshape(x, [batch, rows * cols, hparams.hidden_size])
    flat = common_layers.shift_right_3d(flat)
    x = tf.reshape(flat, [batch, rows, cols, hparams.hidden_size])
  x = add_pos_signals(x, hparams, "dec_pos")
  return common_layers.cast_like(x, targets), rows, cols
# Example #4
# 0
   def cast_grad_tpu(g, v):
       """Casts a gradient to its variable's dtype, optionally zero-filling it.

       Should match upstream t2t:
       https://github.com/tensorflow/tensor2tensor/blob/1547c25571633f828ddd74accba76d07d8d043af/tensor2tensor/utils/optimize.py#L232

       Args:
         g: the gradient, or None.
         v: the variable paired with g, or None.

       Returns:
         The (gradient, variable) pair. When both are present the gradient is
         cast like the variable. When the gradient is missing and
         self._zero_grads is truthy, it is replaced by tf.zeros_like(v).
       """
       grad, var = g, v
       if grad is not None and var is not None:
           grad = common_layers.cast_like(grad, var)
       fill_with_zeros = self._zero_grads and grad is None
       if fill_with_zeros:
           grad = tf.zeros_like(var)
       return grad, var
  def bottom_simple(self, x, name, reuse):
    """Looks up `x` in this modality's weight matrix inside a variable scope.

    Args:
      x: an integer Tensor. A rank-4 input has axis 3 squeezed away, and
        lower-rank inputs get trailing unit axes until they are rank 3.
      name: variable scope name.
      reuse: passed through to tf.variable_scope.

    Returns:
      The gathered rows, multiplied by body_input_depth**0.5 when
      multiply_embedding_mode is "sqrt_depth". Positions whose id (after
      dropout_no_scaling) equals 0 are zeroed out.
    """
    with tf.variable_scope(name, reuse=reuse):
      if len(x.get_shape()) == 4:
        x = tf.squeeze(x, axis=3)
      for _ in range(3 - len(x.get_shape())):
        x = tf.expand_dims(x, axis=-1)

      embedding_table = self._get_weights()
      ids = common_layers.dropout_no_scaling(
          x, 1.0 - self._model_hparams.symbol_dropout)
      embedded = common_layers.gather(embedding_table, ids)
      if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
        embedded *= self._body_input_depth**0.5
      nonzero_mask = common_layers.cast_like(tf.not_equal(ids, 0), embedded)
      embedded *= tf.expand_dims(nonzero_mask, -1)
      return embedded
# Example #6
# 0
 def _resource_apply_dense(self, grad, handle):
     """Builds the ops that apply one optimizer step to a dense variable.

     The step works as follows:
       1. Keep running averages of the squared gradient. These are factored
          into row/column slots "vr"/"vc" when
          self._should_use_factored_second_moment_estimate(shape) is true;
          otherwise a full slot "v" is used.
       2. Normalize the gradient by those averages.
       3. Optionally clip the result by its RMS.
       4. Scale by the learning rate.
       5. Optionally accumulate first-moment momentum in slot "m".
       6. Subtract the result from the variable.

     NOTE(review): this looks like the Adafactor update rule -- confirm
     against the enclosing class documentation.

     Args:
       grad: the gradient Tensor for the variable; it is cast to float32.
       handle: the variable to update.

     Returns:
       A tf.group op running the variable assignment together with all slot
       assignments.
     """
     var = handle
     grad = tf.to_float(grad)
     # epsilon1 is added before any averaging, presumably so the rsqrt calls
     # below never see zero (assumes self._epsilon1 > 0 -- confirm).
     grad_squared = tf.square(grad) + self._epsilon1
     grad_squared_mean = tf.reduce_mean(grad_squared)
     # Decay rate and learning rate may presumably be given as callables;
     # _call_if_callable resolves them to values here.
     decay_rate = self._call_if_callable(self._decay_rate)
     update_scale = self._call_if_callable(self._learning_rate)
     update_scale = tf.convert_to_tensor(update_scale, name="update_scale")
     # Match the (float32) dtype of the gradient statistics.
     update_scale = tf.cast(update_scale,
                            grad_squared_mean.dtype.base_dtype)
     old_val = var
     # bfloat16 variables go through self._parameter_encoding; decode to
     # float32 before doing arithmetic with them.
     if var.dtype.base_dtype == tf.bfloat16:
         old_val = tf.to_float(self._parameter_encoding.decode(old_val))
     if self._multiply_by_parameter_scale:
         update_scale *= tf.to_float(self._parameter_scale(old_val))
     # HACK: Make things dependent on grad.
     # This confounds the XLA rewriter and keeps it from fusing computations
     # across different variables.  This fusion is a bad for HBM usage, since
     # it causes the gradients to persist in memory.
     decay_rate += grad_squared_mean * 1e-30
     update_scale += grad_squared_mean * 1e-30
     # END HACK
     mixing_rate = 1.0 - decay_rate
     shape = var.get_shape().as_list()
     updates = []
     if self._should_use_factored_second_moment_estimate(shape):
         # Factored second moment: keep running averages only of the
         # last-axis means ("vr") and second-to-last-axis means ("vc").
         grad_squared_row_mean = tf.reduce_mean(grad_squared, -1)
         grad_squared_col_mean = tf.reduce_mean(grad_squared, -2)
         vr = self.get_slot(var, "vr")
         new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
         vc = self.get_slot(var, "vc")
         new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
         vr_update = tf.assign(vr, new_vr, use_locking=self._use_locking)
         vc_update = tf.assign(vc, new_vc, use_locking=self._use_locking)
         updates = [vr_update, vc_update]
         # Rebuild a per-element normalizer from the two factors; the row
         # factor is first divided by its own mean along the last axis.
         long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
         r_factor = tf.rsqrt(new_vr / long_term_mean)
         c_factor = tf.rsqrt(new_vc)
         x = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(
             c_factor, -2)
     else:
         # Unfactored second moment: one running average per element.
         v = self.get_slot(var, "v")
         new_v = decay_rate * v + mixing_rate * grad_squared
         v_update = tf.assign(v, new_v, use_locking=self._use_locking)
         updates = [v_update]
         x = grad * tf.rsqrt(new_v)
     if self._clipping_threshold is not None:
         # Divide x by max(1, reduce_rms(x) / threshold), i.e. only ever
         # scale it down.
         clipping_denom = tf.maximum(
             1.0,
             reduce_rms(x) / self._clipping_threshold)
         x /= clipping_denom
     subtrahend = update_scale * x
     if self._beta1:
         # Momentum: computed in float32, then cast like the variable before
         # being stored in slot "m".
         m = self.get_slot(var, "m")
         new_m = self._beta1 * tf.to_float(m) + (1.0 -
                                                 self._beta1) * subtrahend
         subtrahend = new_m
         new_m = common_layers.cast_like(new_m, var)
         updates.append(tf.assign(m, new_m, use_locking=self._use_locking))
     new_val = tf.to_float(old_val) - subtrahend
     if var.dtype.base_dtype == tf.bfloat16:
         new_val = self._parameter_encoding.encode(new_val,
                                                   self._quantization_noise)
     if self._simulated_quantize_bits:
         # NOTE(review): this quantizes `var - subtrahend` instead of new_val,
         # discarding the bfloat16 encode above and subtracting from the raw
         # (undecoded) variable -- confirm this is intended.
         new_val = quantization.simulated_quantize(
             var - subtrahend, self._simulated_quantize_bits,
             self._quantization_noise)
     # Assign in the variable's own dtype.
     new_val = tf.cast(new_val, var.dtype)
     var_update = tf.assign(var, new_val, use_locking=self._use_locking)
     # The variable update comes first, followed by the slot updates.
     updates = [var_update] + updates
     return tf.group(*updates)
# Example #7
# 0
 def cast_grad(g, v):
   """Returns (grad, v) with the gradient cast like v when both are present.

   If the gradient is None and self._zero_grads is truthy, it is replaced by
   tf.zeros_like(v).
   """
   grad = common_layers.cast_like(g, v) if (v is not None and
                                            g is not None) else g
   if self._zero_grads and grad is None:
     grad = tf.zeros_like(v)
   return grad, v
# Example #8
# 0
 def _resource_apply_dense(self, grad, handle):
   """Builds the ops that apply one optimizer step to a dense variable.

   The step works as follows:
     1. Keep running averages of the squared gradient. These are factored
        into row/column slots "vr"/"vc" when
        self._should_use_factored_second_moment_estimate(shape) is true;
        otherwise a full slot "v" is used.
     2. Normalize the gradient by those averages.
     3. Optionally clip the result by its RMS.
     4. Scale by the learning rate.
     5. Optionally accumulate momentum in slot "m".
     6. Subtract the result from the variable.

   Args:
     grad: the gradient Tensor for the variable; it is cast to float32.
     handle: the variable to update.

   Returns:
     A tf.group op running the variable assignment together with all slot
     assignments.
   """
   var = handle
   grad = tf.to_float(grad)
   grad_squared = tf.square(grad) + self._epsilon1
   grad_squared_mean = tf.reduce_mean(grad_squared)
   decay_rate = self._decay_rate
   # Normalize the learning rate to a tensor in the statistics' dtype
   # (float32), so that e.g. a float64 learning-rate tensor does not cause a
   # dtype mismatch below. A Python float yields the same values as before.
   update_scale = tf.convert_to_tensor(self._learning_rate,
                                       name="update_scale")
   update_scale = tf.cast(update_scale, grad_squared_mean.dtype.base_dtype)
   old_val = var
   # bfloat16 variables go through self._parameter_encoding; decode to
   # float32 before doing arithmetic with them.
   if var.dtype.base_dtype == tf.bfloat16:
     old_val = tf.to_float(self._parameter_encoding.decode(old_val))
   if self._multiply_by_parameter_scale:
     update_scale *= tf.to_float(self._parameter_scale(old_val))
   # HACK: Make things dependent on grad.
   # This confounds the XLA rewriter and keeps it from fusing computations
   # across different variables.  This fusion is a bad for HBM usage, since
   # it causes the gradients to persist in memory.
   decay_rate += grad_squared_mean * 1e-30
   update_scale += grad_squared_mean * 1e-30
   # END HACK
   mixing_rate = 1.0 - decay_rate
   shape = var.get_shape().as_list()
   updates = []
   if self._should_use_factored_second_moment_estimate(shape):
     # Factored second moment: running averages of the last-axis means ("vr")
     # and second-to-last-axis means ("vc") only.
     grad_squared_row_mean = tf.reduce_mean(grad_squared, -1)
     grad_squared_col_mean = tf.reduce_mean(grad_squared, -2)
     vr = self.get_slot(var, "vr")
     new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
     vc = self.get_slot(var, "vc")
     new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
     vr_update = tf.assign(vr, new_vr, use_locking=self._use_locking)
     vc_update = tf.assign(vc, new_vc, use_locking=self._use_locking)
     updates = [vr_update, vc_update]
     long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
     r_factor = tf.rsqrt(new_vr / long_term_mean)
     c_factor = tf.rsqrt(new_vc)
     x = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(c_factor, -2)
   else:
     # Unfactored second moment: one running average per element.
     v = self.get_slot(var, "v")
     new_v = decay_rate * v + mixing_rate * grad_squared
     v_update = tf.assign(v, new_v, use_locking=self._use_locking)
     updates = [v_update]
     x = grad * tf.rsqrt(new_v)
   if self._clipping_threshold is not None:
     # Only ever scale x down: divide by max(1, rms(x) / threshold).
     clipping_denom = tf.maximum(1.0, reduce_rms(x) / self._clipping_threshold)
     x /= clipping_denom
   subtrahend = update_scale * x
   if self._beta1:
     m = self.get_slot(var, "m")
     new_m = self._beta1 * tf.to_float(m) + (1.0 - self._beta1) * subtrahend
     subtrahend = new_m
     new_m = common_layers.cast_like(new_m, var)
     updates.append(tf.assign(m, new_m, use_locking=self._use_locking))
   new_val = tf.to_float(old_val) - subtrahend
   if var.dtype.base_dtype == tf.bfloat16:
     new_val = self._parameter_encoding.encode(
         new_val, self._quantization_noise)
   if self._simulated_quantize_bits:
     # NOTE(review): quantizes `var - subtrahend` rather than new_val; kept
     # as-is, confirm it is intended.
     new_val = quantization.simulated_quantize(
         var - subtrahend, self._simulated_quantize_bits,
         self._quantization_noise)
   # new_val is float32 arithmetic; cast it to the variable's own dtype so
   # tf.assign does not fail for variables that are neither float32 nor
   # (already-encoded) bfloat16. This is a no-op when the dtypes match.
   new_val = tf.cast(new_val, var.dtype)
   var_update = tf.assign(var, new_val, use_locking=self._use_locking)
   updates = [var_update] + updates
   return tf.group(*updates)
# Example #9
# 0
 def cast_grad(g, v):
   """Returns the (gradient, variable) pair, casting g like v when both exist."""
   if g is None or v is None:
     return (g, v)
   return (common_layers.cast_like(g, v), v)
# Example #10
# 0
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
  """Builds the encoder input and attention biases for one model shard.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  static_shape = inputs.shape.as_list()
  encoder_input = inputs
  packed = features and "inputs_segmentation" in features
  if packed:
    # Packed dataset: keep the examples from seeing each other.
    inputs_segmentation = features["inputs_segmentation"]
    inputs_position = features["inputs_position"]
    targets_segmentation = features["targets_segmentation"]
  else:
    padding_bias = common_attention.attention_bias_ignore_padding(
        common_attention.embedding_to_padding(encoder_input))
    inputs_position = None

  if (hasattr(hparams, "unidirectional_encoder") and
      hparams.unidirectional_encoder):
    tf.logging.info("Using unidirectional encoder")
    self_bias = common_attention.attention_bias_lower_triangle(
        common_layers.shape_list(inputs)[1])
  elif packed:
    self_bias = common_attention.attention_bias_same_segment(
        inputs_segmentation, inputs_segmentation)
  else:
    # Usual case - not a packed dataset.
    self_bias = padding_bias

  if packed:
    enc_dec_bias = common_attention.attention_bias_same_segment(
        targets_segmentation, inputs_segmentation)
  else:
    enc_dec_bias = padding_bias

  if hparams.proximity_bias:
    self_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(inputs)[1])

  if target_space is not None and hparams.get("use_target_space_embedding",
                                              True):
    # Add a learned target_space_id embedding, broadcast over batch and length.
    space_emb = common_layers.embedding(
        target_space,
        32,
        static_shape[-1],
        name="target_space_embedding",
        dtype=hparams.get("activation_dtype", "float32"))
    encoder_input += tf.reshape(space_emb, [1, 1, -1])

  if hparams.pos == "timing":
    if inputs_position is None:
      encoder_input = common_attention.add_timing_signal_1d(encoder_input)
    else:
      encoder_input = common_attention.add_timing_signal_1d_given_position(
          encoder_input, inputs_position)
  elif hparams.pos == "emb":
    encoder_input = common_attention.add_positional_embedding(
        encoder_input, hparams.max_length, "inputs_positional_embedding",
        inputs_position)

  return (encoder_input,
          common_layers.cast_like(self_bias, encoder_input),
          common_layers.cast_like(enc_dec_bias, encoder_input))
# Example #11
# 0
 def cast_grad(g, v):
   """Pairs gradient g with variable v, matching g's dtype to v when possible."""
   needs_cast = v is not None and g is not None
   grad = common_layers.cast_like(g, v) if needs_cast else g
   return grad, v
# Example #12
# 0
def dot_product_area_attention(q,
                               k,
                               v,
                               bias,
                               dropout_rate=0.0,
                               image_shapes=None,
                               name=None,
                               attention_image_summary=None,
                               save_weights_to=None,
                               dropout_broadcast_dims=None,
                               max_area_width=1,
                               max_area_height=1,
                               memory_height=1,
                               area_key_mode="mean",
                               area_value_mode="sum",
                               top_k_areas=0,
                               area_temperature=1.0,
                               training=True):
    """Dot-product area attention.

    Keys and values are aggregated over areas of the memory (via
    compute_area_key / compute_area_features / basic_pool), and each query
    attends over those areas.

    Args:
      q: Tensor with shape [batch, heads, length_q, depth_k].
      k: Tensor with shape [batch, heads, length_kv, depth_k]. Leading
        dimensions must match with q.
      v: Tensor with shape [batch, heads, length_kv, depth_v]. Leading
        dimensions must match with q; depth_v may differ from depth_k.
      bias: bias Tensor (see attention_bias()). Entries below -1 are treated
        as masked memory positions.
      dropout_rate: a float.
      image_shapes: optional tuple of integer scalars.
        see comments for attention_image_summary()
      name: an optional string
      attention_image_summary: the callback for making image summary of
        attention.
      save_weights_to: an optional dictionary to capture attention weights
        for visualization; the weights tensor will be appended there under
        a string key created from the variable scope (including name).
      dropout_broadcast_dims: an optional list of integers less than rank of q.
        Specifies in which dimensions to broadcast the dropout decisions.
      max_area_width: the max width allowed for an area.
      max_area_height: the max height allowed for an area.
      memory_height: the height of the memory.
      area_key_mode: the mode for computing area keys, which can be "mean",
        "concat", "sum", "sample_concat", and "sample_sum".
      area_value_mode: the mode for computing area values, which can be
        "mean", "max", or "sum".
      top_k_areas: Use the top key areas for attention.
      area_temperature: the temperature for attention softmax.
      training: indicating if it is in the training mode.

    Returns:
      Tensor with shape [batch, heads, length_q, depth_v].

    Raises:
      ValueError: if area_value_mode is not "mean", "max" or "sum".
    """
    tf.logging.info(
        "dot_product_area_attention: "
        "area_h=%d, area_w=%d, mem_h=%d, "
        "area_key_mode=%s, area_value_mode=%s, "
        "area_temperature=%f", max_area_height, max_area_width, memory_height,
        area_key_mode, area_value_mode, area_temperature)
    with tf.variable_scope(name,
                           default_name="dot_product_area_attention",
                           values=[q, k, v]) as scope:
        mem_shape = common_layers.shape_list(k)
        batch_size = mem_shape[0]
        head_size = mem_shape[1]
        length = mem_shape[2]
        depth = mem_shape[3]
        k_area = compute_area_key(tf.reshape(k, [-1, length, depth]),
                                  max_area_width=max_area_width,
                                  max_area_height=max_area_height,
                                  height=memory_height,
                                  mode=area_key_mode,
                                  training=training)
        # Values may have a different depth than keys; flatten and unflatten v
        # with its own last dimension instead of reusing the key depth.
        depth_v = common_layers.shape_list(v)[-1]
        flat_v = tf.reshape(v, [-1, length, depth_v])
        if area_value_mode == "mean":
            v_area, _, _, _, _ = compute_area_features(
                flat_v,
                max_area_width=max_area_width,
                max_area_height=max_area_height,
                height=memory_height)
        elif area_value_mode == "max":
            v_area, _, _ = basic_pool(flat_v,
                                      max_area_width=max_area_width,
                                      max_area_height=max_area_height,
                                      height=memory_height,
                                      fn=tf.reduce_max)
        elif area_value_mode == "sum":
            _, _, v_area, _, _ = compute_area_features(
                flat_v,
                max_area_width=max_area_width,
                max_area_height=max_area_height,
                height=memory_height)
        else:
            raise ValueError("Unsupported area value mode=%s" %
                             area_value_mode)
        k = tf.reshape(k_area, [batch_size, head_size, -1, depth])
        v = tf.reshape(v_area, [batch_size, head_size, -1, depth_v])
        logits = tf.matmul(q, k,
                           transpose_b=True)  # [..., length_q, num_areas]
        if bias is not None:
            bias = common_layers.cast_like(bias, logits)
            with tf.name_scope("compute_area_att_bias", values=[bias]):
                bias_shape = common_layers.shape_list(bias)
                mem_length = bias_shape[-1]
                # 1.0 where the memory position is masked, 0.0 otherwise.
                bias_values = tf.reshape(tf.to_float(tf.less(bias, -1)),
                                         [-1, mem_length, 1])
                # Aggregate the mask per area (the same output slot used for
                # the "sum" value mode); any nonzero area gets a -inf bias.
                _, _, padding_sum, _, _ = compute_area_features(
                    bias_values,
                    max_area_width=max_area_width,
                    max_area_height=max_area_height,
                    height=memory_height)
                bias = tf.where(tf.cast(tf.to_int32(padding_sum), tf.bool),
                                tf.fill(tf.shape(padding_sum), -np.inf),
                                tf.zeros_like(padding_sum, dtype=tf.float32))
                bias = tf.reshape(
                    bias, [bias_shape[0], bias_shape[1], bias_shape[2], -1])
                # The area bias is rebuilt in float32 above; cast it back so
                # low-precision logits (e.g. bfloat16) can be added to it.
                bias = common_layers.cast_like(bias, logits)
            logits += bias
        logits = logits / area_temperature
        weights = tf.nn.softmax(logits, name="attention_weights")
        if top_k_areas > 0:
            tf.logging.info("area_attention top_k_areas=%d", top_k_areas)
            top_k = tf.minimum(
                common_layers.shape_list(weights)[-1], top_k_areas)
            top_weights, _ = tf.nn.top_k(weights, k=top_k)
            # Keep only weights at least as large as the k-th largest, then
            # renormalize them to sum to 1.
            min_values = tf.reduce_min(top_weights, -1, keepdims=True)
            weights = tf.where(tf.greater_equal(weights, min_values), weights,
                               tf.zeros_like(weights))
            weights = tf.div(weights, tf.reduce_sum(weights, -1,
                                                    keepdims=True))
        if save_weights_to is not None:
            save_weights_to[scope.name] = weights
            save_weights_to[scope.name + "/logits"] = logits
        # Drop out attention links for each head.
        weights = common_layers.dropout_with_broadcast_dims(
            weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
        if common_layers.should_generate_summaries(
        ) and attention_image_summary:
            attention_image_summary(weights, image_shapes)
        return tf.matmul(weights, v)
def transformer_prepare_encoder(inputs,
                                target_space,
                                hparams,
                                features=None,
                                type_ids=None,
                                num_types=None,
                                reuse_target_embedding=tf.AUTO_REUSE):
    """Builds the encoder input and attention biases for one model shard.

    Args:
      inputs: a Tensor.
      target_space: a Tensor.
      hparams: run hyperparameters
      features: optionally pass the entire features dictionary as well.
        This is needed now for "packed" datasets.
      type_ids: optional, an int64 Tensor of shape [batch, length] that allows
        for adding type embeddings, similar to positional embeddings.
      num_types: optional, an int that decides the number of types in type_ids.
      reuse_target_embedding: option to reuse variable name in the case that
        symbol modalities are reused between inputs/targets.

    Returns:
      encoder_input: a Tensor, bottom of encoder stack
      encoder_self_attention_bias: a bias tensor for use in encoder
        self-attention
      encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
        attention

    Raises:
      ValueError: if type_ids is given without num_types.
    """
    static_shape = inputs.shape.as_list()
    encoder_input = inputs
    packed = features and "inputs_segmentation" in features
    if packed:
        # Packed dataset: keep the examples from seeing each other.
        inputs_segmentation = features["inputs_segmentation"]
        inputs_position = features["inputs_position"]
        targets_segmentation = features["targets_segmentation"]
    else:
        padding_bias = common_attention.attention_bias_ignore_padding(
            common_attention.embedding_to_padding(encoder_input))
        inputs_position = None

    if (hasattr(hparams, "unidirectional_encoder")
            and hparams.unidirectional_encoder):
        tf.logging.info("Using unidirectional encoder")
        self_bias = common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(inputs)[1])
    elif packed:
        self_bias = common_attention.attention_bias_same_segment(
            inputs_segmentation, inputs_segmentation)
    else:
        # Usual case - not a packed dataset.
        self_bias = padding_bias

    if packed:
        enc_dec_bias = common_attention.attention_bias_same_segment(
            targets_segmentation, inputs_segmentation)
    else:
        enc_dec_bias = padding_bias

    if hparams.proximity_bias:
        self_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(inputs)[1])

    if target_space is not None and hparams.get("use_target_space_embedding",
                                                True):
        # Add a learned target_space_id embedding, broadcast over batch and
        # length.
        space_emb = common_layers.embedding(
            target_space,
            32,
            static_shape[-1],
            name="target_space_embedding",
            dtype=hparams.get("activation_dtype", "float32"),
            reuse=reuse_target_embedding)
        encoder_input += tf.reshape(space_emb, [1, 1, -1])

    if hparams.pos == "timing":
        if inputs_position is None:
            encoder_input = common_attention.add_timing_signal_1d(
                encoder_input)
        else:
            encoder_input = common_attention.add_timing_signal_1d_given_position(
                encoder_input, inputs_position)
    elif hparams.pos == "timing_from_features":
        encoder_input = common_attention.add_timing_signals_from_features(
            encoder_input, features, hparams.position_features)
    elif hparams.pos == "emb":
        encoder_input = common_attention.add_positional_embedding(
            encoder_input, hparams.max_length, "inputs_positional_embedding",
            inputs_position)

    # Type embeddings reuse the positional-embedding helper, keyed by type_ids.
    if type_ids is not None:
        if not num_types:
            raise ValueError("Need to set num_types as well.")
        encoder_input = common_attention.add_positional_embedding(
            encoder_input, num_types, "inputs_type_embedding", type_ids)

    return (encoder_input,
            common_layers.cast_like(self_bias, encoder_input),
            common_layers.cast_like(enc_dec_bias, encoder_input))
# Example #14
# 0
def dot_product_attention_mtsa(
    q,
    k,
    v,
    bias,
    dropout_rate=0.0,
    image_shapes=None,
    name=None,
    make_image_summary=True,
    save_weights_to=None,
    dropout_broadcast_dims=None,
    use_k_mtsa=True,
    afn_extra='none',
    afn_dot='exp',
    afn_multi='exp',
    bias_start=0.,
    bi_direction=False,
):
    """Dot-product attention combined with per-feature source2token scores.

    Two score tensors are computed:
      * A = afn_dot(q @ k^T + bias): token2token scores, [..., length_q, length_kv].
        When bi_direction is set, each head is masked to one triangle.
      * M = afn_multi(scale * dense(k or v)): source2token scores,
        one per key position and per value feature, [..., length_kv, depth_v].

    The output for query i and feature d is
        sum_j A[i, j] * M[j, d] * v[j, d] / sum_j A[i, j] * M[j, d].
    The denominator is computed from the scores before dropout is applied.

    Args:
      q: Tensor with shape [..., length_q, depth_k].
      k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
        match with q.
      v: Tensor with shape [..., length_kv, depth_v]. Leading dimensions must
        match with q. The static last dimension is required (used as the
        dense-layer width). With bi_direction, axis 1 must also be static and
        is treated as the head axis.
      bias: bias Tensor (see attention_bias()), added to q @ k^T; may be None.
      dropout_rate: a float. Dropout is applied separately to A and M, each
        with keep probability sqrt(1 - dropout_rate). Presumably this is
        chosen so that their product is kept with probability
        1 - dropout_rate.
      image_shapes: optional tuple of integer scalars.
        see comments for attention_image_summary()
      name: an optional string
      make_image_summary: True if you want an image summary.
      save_weights_to: an optional dictionary to capture attention weights
        for visualization. The normalized token2token weights are stored
        under the variable-scope name, and the raw logits under
        "<scope>/logits".
      dropout_broadcast_dims: an optional list of integers less than rank of q.
        Specifies in which dimensions to broadcast the dropout decisions.
      use_k_mtsa: if True the source2token logits are computed from k,
        otherwise from v.
      afn_extra: activation name resolved via afn_name2fn. If it resolves to
        None, a single dense layer produces the source2token logits. Otherwise
        a second dense layer is applied to afn_extra(first_layer_output).
      afn_dot: activation name (via afn_name2fn) applied to the token2token
        logits.
      afn_multi: activation name (via afn_name2fn) applied to the
        source2token logits. Names starting with 'scaled' additionally
        multiply the logits by 1/sqrt(depth_v) before the activation.
      bias_start: initial bias passed to the last dense layer that produces
        the source2token logits.
      bi_direction: if True, the first half of the heads only sees keys at
        positions <= the query position (lower triangle). The second half only
        sees keys at positions >= the query position (upper triangle).
        Requires a static, even head count.

    Returns:
      Tensor with shape [..., length_q, depth_v].
    """
    with tf.variable_scope(name,
                           default_name="dot_product_attention",
                           values=[q, k, v]) as scope:
        dim_v = v.get_shape().as_list()[-1]
        # The 'scaled' prefix must be read from the *name*, before the names
        # are replaced by callables below.
        multi_logits_scale_factor = (1. / math.sqrt(dim_v)
                                     if afn_multi.startswith('scaled') else 1.)
        afn_extra = afn_name2fn(afn_extra)
        afn_dot = afn_name2fn(afn_dot)
        afn_multi = afn_name2fn(afn_multi)

        # token2token self attention
        dot_logits = tf.matmul(q, k, transpose_b=True)  # bs,hd,ql,vl
        if bias is not None:
            bias = common_layers.cast_like(bias, dot_logits)  # 1/bs,1,ql/1,vl
            dot_logits += bias
        e_dot_logits = afn_dot(dot_logits)  # bs,hd,ql,vl
        if bi_direction:
            head_num = v.get_shape().as_list()[1]
            assert head_num is not None, (
                "bi_direction requires a static head count (axis 1 of v).")
            assert head_num % 2 == 0, (
                "bi_direction requires an even head count, got %d." % head_num)
            ql, vl = tf.shape(q)[-2], tf.shape(v)[-2]
            # Build the mask in the scores' dtype so the in-place multiply
            # below also works for float16/bfloat16 attention.
            ones_mat = tf.ones([ql, vl], dtype=e_dot_logits.dtype)
            mul_mask_fw = tf.matrix_band_part(ones_mat, -1, 0)  # lower triangle
            mul_mask_bw = tf.matrix_band_part(ones_mat, 0, -1)  # upper triangle
            half_heads = head_num // 2
            mul_mask_fw_tile = tf.tile(tf.expand_dims(mul_mask_fw, 0),
                                       [half_heads, 1, 1])
            mul_mask_bw_tile = tf.tile(tf.expand_dims(mul_mask_bw, 0),
                                       [half_heads, 1, 1])
            # [1, head_num, ql, vl]: forward heads first, then backward heads.
            mul_mask = tf.expand_dims(
                tf.concat([mul_mask_fw_tile, mul_mask_bw_tile], axis=0), axis=0)
            e_dot_logits *= mul_mask

        # source2token self-attention (one score per key position and feature)
        multi_logits = multi_head_dense_layer(
            k if use_k_mtsa else v, dim_v, True,
            bias_start if afn_extra is None else 0., 'multi_logits1')
        if afn_extra is not None:  # use one extra layer for multi-dim
            multi_logits = multi_head_dense_layer(afn_extra(multi_logits),
                                                  dim_v, True, bias_start,
                                                  'multi_logits2')
        e_multi_logits = afn_multi(multi_logits *
                                   multi_logits_scale_factor)  # bs,hd,vl,vd

        # Normalizer: sum_j A[i, j] * M[j, d], computed before dropout.
        accum_z_deno = tf.matmul(e_dot_logits, e_multi_logits)  # bs,hd,ql,vd
        # Replace zero, negative or NaN denominators with 1, so that fully
        # masked rows do not divide by zero (NaN > 0 is False).
        accum_z_deno = tf.where(
            tf.greater(accum_z_deno, tf.zeros_like(accum_z_deno)),
            accum_z_deno, tf.ones_like(accum_z_deno))

        # attention dropout, split evenly between the two score tensors
        keep_prob = math.sqrt(1. - dropout_rate)
        e_dot_logits = common_layers.dropout_with_broadcast_dims(
            e_dot_logits,
            keep_prob,
            broadcast_dims=dropout_broadcast_dims)
        e_multi_logits = common_layers.dropout_with_broadcast_dims(
            e_multi_logits,
            keep_prob,
            broadcast_dims=dropout_broadcast_dims)
        rep_mul_score = v * e_multi_logits  # bs,hd,vl,vd
        accum_rep_mul_score = tf.matmul(e_dot_logits,
                                        rep_mul_score)  # bs,hd,ql,vd
        # calculate the final attention results
        attn_res = accum_rep_mul_score / accum_z_deno

        # Row-normalized token2token weights, for visualization only.
        weights = e_dot_logits / (tf.reduce_sum(
            e_dot_logits, axis=-1, keepdims=True, name="attention_weights") +
                                  0.00001)
        if save_weights_to is not None:
            save_weights_to[scope.name] = weights
            save_weights_to[scope.name + "/logits"] = dot_logits
        if common_layers.should_generate_summaries() and make_image_summary:
            common_attention.attention_image_summary(weights, image_shapes)
        return attn_res