Example #1
def rename_length_to_memory_length(x,
                                   length_name="length",
                                   memory_length_name="memory_length"):
    return mtf.rename_dimension(x, length_name, memory_length_name)
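Example #1 is a thin wrapper around mtf.rename_dimension. Below is a minimal usage sketch, not taken from the source: the mesh name and the dimension sizes are illustrative assumptions.

import tensorflow as tf
import mesh_tensorflow as mtf

# Build a tiny mesh and a [batch=4, length=8] tensor of zeros (assumed sizes).
graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch_dim = mtf.Dimension("batch", 4)
length_dim = mtf.Dimension("length", 8)
x = mtf.zeros(mesh, mtf.Shape([batch_dim, length_dim]), dtype=tf.float32)

# Same values, but the second dimension is now named "memory_length", so later
# einsums will no longer unify it with a "length" dimension.
memory = rename_length_to_memory_length(x)
assert memory.shape.dims[1].name == "memory_length"
assert memory.shape.dims[1].size == 8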
Example #2
def local_self_attention_spatial_blocks(query_antecedent,
                                        kv_channels,
                                        heads,
                                        memory_w_dim=None,
                                        mask_right=False,
                                        name=None):
    """Attention to the source position and a neighborhood to the left or right.

  The sequence is divided into blocks of length block_size.
  Attention for a given query position can only see memory positions
  less than or equal to the query position, in the corresponding block
  and the previous block.

  Args:
    query_antecedent: a mtf.Tensor with shape
      [batch, num_h_blocks, num_w_blocks, h_dim, w_dim, io_channels]
      must have the same size as query_length, but a different name.
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    memory_w_dim: mtf Dimension, for the memory width block.
    mask_right: bool, flag specifying whether we mask out attention to the right
      for the decoder.
    name: an optional string.

  Returns:
    a Tensor of shape
        [batch, num_h_blocks, num_w_blocks, h_dim, w_dim, io_channels]

  Raises:
    ValueError: if channels or depth don't match.
  """
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent]):

        w_dim, io_channels = query_antecedent.shape.dims[-2:]
        batch, num_w_blocks = query_antecedent.shape.dims[:2]
        q_var, k_var, v_var, o_var = multihead_attention_vars(
            query_antecedent.mesh, heads, io_channels, kv_channels,
            query_antecedent.dtype)

        # Rename the width dimension for the memory tensor.
        memory_antecedent = mtf.rename_dimension(query_antecedent, w_dim.name,
                                                 memory_w_dim.name)

        # Call einsum over the query and memory to get query q, keys k and values v.
        q = mtf.einsum([query_antecedent, q_var],
                       mtf.Shape(
                           [batch, heads, num_w_blocks, w_dim, kv_channels]))
        k = mtf.einsum([memory_antecedent, k_var],
                       mtf.Shape(
                           [batch, heads, num_w_blocks, memory_w_dim,
                            kv_channels]))
        v = mtf.einsum([memory_antecedent, v_var],
                       mtf.Shape(
                           [batch, heads, num_w_blocks, memory_w_dim,
                            kv_channels]))

        # Halo exchange for memory blocks.
        if memory_w_dim is not None:
            k, v = local_1d_halo_exchange(k, v, num_w_blocks, w_dim,
                                          memory_w_dim, mask_right)

        # Calculate the causal mask to avoid peeking into the future. We compute
        # this once and reuse it for all blocks since the block_size is known.
        mask = None
        if mask_right:
            mask = attention_bias_local_block(query_antecedent.mesh, w_dim,
                                              memory_w_dim)

        output = dot_product_attention(q, k, v, mask=mask)

        return mtf.einsum([output, o_var],
                          mtf.Shape([batch, num_w_blocks, w_dim, io_channels]))
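The rename of w_dim to memory_w_dim above matters because Mesh TensorFlow einsum matches dimensions by name: the query positions and the memory positions must carry different names, or the einsum that forms the attention logits would contract them together. A minimal sketch of that point, with assumed sizes and names:

import tensorflow as tf
import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch = mtf.Dimension("batch", 2)
heads = mtf.Dimension("heads", 4)
length = mtf.Dimension("length", 8)
memory_length = mtf.Dimension("memory_length", 8)
kv = mtf.Dimension("kv_channels", 16)

q = mtf.zeros(mesh, mtf.Shape([batch, heads, length, kv]), dtype=tf.float32)
k = mtf.zeros(mesh, mtf.Shape([batch, heads, memory_length, kv]),
              dtype=tf.float32)

# kv_channels appears in both inputs but not in the output, so it is
# contracted; both position dimensions survive in the logits.
logits = mtf.einsum([q, k],
                    mtf.Shape([batch, heads, length, memory_length]))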
Example #3
  def _sample(self, features, mesh):
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if self.has_input:
      inputs = features["inputs"]
      while len(inputs.shape.as_list()) > 2:
        inputs = tf.squeeze(inputs, axis=2)
      actual_batch_size = tf.shape(inputs)[0]
      actual_length = tf.shape(inputs)[1]
      inputs = tf.pad(
          inputs, [[0, hparams.batch_size - actual_batch_size],
                   [0, hparams.max_length - actual_length]])
      inputs = self._import_to_batch_by_length(
          inputs, "inputs", mesh, hparams)
      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.reshape(positional_embedding_var,
                       mtf.Shape([self.length_dim, self.model_dim])))
      encoder_attention_mask = (
          mtf_layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.num_encoder_layers,
                              self_attention_mask=encoder_attention_mask)
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
      encdec_tensors = []
      for layer_num in xrange(hparams.num_decoder_layers):
        with tf.variable_scope("decoder/layer_%d/encdec_attention" % layer_num):
          q_var, k_var, v_var, o_var = mtf_layers.multihead_attention_vars(
              mesh, self.heads_dim, self.model_dim,
              self.kv_dim, self.activation_dtype)
          k = mtf.einsum(
              [encoder_output, k_var],
              mtf.Shape(
                  [self.batch_dim, self.heads_dim,
                   self.memory_length_dim, self.kv_dim]))
          v = mtf.einsum(
              [encoder_output, v_var],
              mtf.Shape(
                  [self.batch_dim, self.heads_dim,
                   self.memory_length_dim, self.kv_dim]))
        encdec_tensors.append((q_var, o_var, k, v))
      partial_targets = None
    else:
      encdec_tensors = None
      encoder_output = None
      encoder_attention_mask = None
      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs", None)
      if partial_targets is None:
        partial_targets = features.get("targets", None)
      if partial_targets is not None:
        partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
        partial_targets = tf.to_int32(partial_targets)
        partial_targets_batch = tf.shape(partial_targets)[0]
        partial_targets_length = tf.shape(partial_targets)[1]
        partial_targets = tf.pad(
            partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                              [0, hparams.max_length - partial_targets_length]])
        partial_targets = self._import_to_batch_by_length(
            partial_targets, "partial_targets", mesh, hparams)

    if hparams.beam_size == 1:
      ids_shape = mtf.Shape([self.batch_dim, self.length_dim])
      kv_shape = mtf.Shape([self.batch_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
    else:
      beam_dim = mtf.Dimension("beam", hparams.beam_size)
      ids_shape = mtf.Shape([self.batch_dim, beam_dim, self.length_dim])
      kv_shape = mtf.Shape([self.batch_dim, beam_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])

    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    initial_kv_states = (
        [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)]
        * (2 * hparams.num_decoder_layers))
    def logits_fn(step_num, ids, states):
      """Produce logits for this step, and new states."""
      self_attention_k = states[:hparams.num_decoder_layers]
      self_attention_v = states[hparams.num_decoder_layers:]
      ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
      x = (mtf.gather(targets_embedding_var, ids_this_step,
                      self.targets_vocab_dim) +
           mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
      with tf.variable_scope("decoder"):
        x, new_self_attention_k, new_self_attention_v = (
            self._decoder_layer_stack_incremental(
                x,
                step_num,
                encdec_tensors,
                self_attention_k,
                self_attention_v,
                encdec_attention_mask=encoder_attention_mask))
      logits = mtf.matmul(x, softmax_var)
      return logits, new_self_attention_k + new_self_attention_v

    if hparams.beam_size == 1:
      temperature = (0.0 if hparams.sampling_method == "argmax"
                     else hparams.sampling_temp)
      return mtf_beam_search.greedy_decode(
          logits_fn,
          initial_ids,
          temperature=temperature,
          initial_states=initial_kv_states,
          forced_ids=partial_targets,
          use_tpu=hparams.use_tpu)
    else:
      if self.has_input:
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=self.length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * hparams.decode_length_multiplier
            + hparams.decode_length_constant, tf.int32)
      else:
        decode_length = None
      beams, unused_scores = mtf_beam_search.beam_search(
          logits_fn,
          initial_ids,
          hparams.alpha,
          states=initial_kv_states,
          decode_length=decode_length,
          use_tpu=hparams.use_tpu)
      return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
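Before importing inputs or partial targets to the mesh, the code above pads them to the static [batch_size, max_length] shape. A small sketch of that padding step, with assumed sizes (a dynamic [2, 5] batch padded up to [4, 8]):

import tensorflow as tf

hparams_batch_size = 4   # assumed hparams.batch_size
hparams_max_length = 8   # assumed hparams.max_length

inputs = tf.constant([[3, 7, 7, 1, 0],
                      [5, 2, 1, 0, 0]], dtype=tf.int32)

actual_batch_size = tf.shape(inputs)[0]   # 2
actual_length = tf.shape(inputs)[1]       # 5

padded = tf.pad(
    inputs, [[0, hparams_batch_size - actual_batch_size],
             [0, hparams_max_length - actual_length]])
# padded has shape [4, 8]; the extra rows and columns are zeros, which the
# downstream code treats as padding tokens.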
Example #4
  def _mtf_model_fn(self, features, mesh):
    features = copy.copy(features)
    hparams = self._hparams
    targets = tf.to_int32(features["targets"])
    if len(targets.get_shape()) > 2:
      tf.logging.info("targets = %s" % targets)
      targets = tf.squeeze(targets, [2, 3])
    # pad targets to max_length
    def pad_to_max_length(x):
      extra_length = hparams.max_length - tf.shape(x)[1]
      x = tf.pad(x, [[0, 0], [0, extra_length]])
      x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
      return x
    targets = pad_to_max_length(targets)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
      if key in features:
        features[key] = pad_to_max_length(features[key])
    shifted_targets = common_layers.shift_right_2d(targets)

    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    shifted_targets = self._import_to_batch_by_length(
        shifted_targets, "shifted_targets", mesh, hparams)

    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = self._import_to_batch_by_length(
          features["targets_segmentation"], "targets_segmentation",
          mesh, hparams)
      targets_position = self._import_to_batch_by_length(
          features["targets_position"], "targets_position",
          mesh, hparams)
      decoder_self_attention_mask = (
          mtf_layers.attention_mask_autoregressive(
              targets_position, dtype=self.activation_dtype) +
          mtf_layers.attention_mask_same_segment(
              targets_segmentation, dtype=self.activation_dtype))
    else:
      targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      decoder_self_attention_mask = mtf_layers.attention_mask_autoregressive(
          targets_position, dtype=self.activation_dtype)

    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))

    extra_losses = []
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if self.has_input:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = pad_to_max_length(inputs)
      inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = self._import_to_batch_by_length(
            features["inputs_segmentation"], "inputs_segmentation",
            mesh, hparams)
        inputs_position = self._import_to_batch_by_length(
            features["inputs_position"], "inputs_position",
            mesh, hparams)
        encoder_self_attention_mask = (
            mtf_layers.attention_mask_same_segment(
                inputs_segmentation, dtype=self.activation_dtype))
        encoder_decoder_attention_mask = (
            mtf_layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        encoder_self_attention_mask = (
            mtf_layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
        encoder_decoder_attention_mask = encoder_self_attention_mask

      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.gather(positional_embedding_var, inputs_position,
                      self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.num_encoder_layers,
                              self_attention_mask=encoder_self_attention_mask,
                              losses=extra_losses)
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
    else:
      encoder_output = None
      encoder_decoder_attention_mask = None

    # DECODER
    x = (mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
         mtf.gather(
             positional_embedding_var, targets_position, self.max_length_dim))
    x = layer_prepostprocess_dropout(x)

    # Decoder
    with tf.variable_scope("decoder"):
      x = self._layer_stack(
          x,
          hparams.num_decoder_layers,
          encoder_output=encoder_output,
          self_attention_mask=decoder_self_attention_mask,
          encdec_attention_mask=encoder_decoder_attention_mask,
          losses=extra_losses)
    logits = mtf.matmul(x, softmax_var)
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf_layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    weights = mtf_layers.weights_nonzero(
        targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
      loss += l
    return logits, loss
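The label-smoothing values above put on_value on the gold token and off_value on every other vocabulary entry, so each smoothed target row still sums to 1. A quick numeric check with assumed numbers (label_smoothing = 0.1, vocabulary of 8 tokens):

# Worked example with assumed hparams, not values from the source.
label_smoothing = 0.1
targets_vocab_size = 8

off_value = label_smoothing / targets_vocab_size   # 0.0125
on_value = 1.0 - label_smoothing + off_value       # 0.9125

# The smoothed one-hot row sums back to 1:
# 0.9125 + 7 * 0.0125 == 1.0
assert abs(on_value + (targets_vocab_size - 1) * off_value - 1.0) < 1e-9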
Example #5
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
    """Bottleneck block variant for residual networks with BN after convolutions.

  Args:
    inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters for the first two convolutions. Note
        that the third and final convolution will use 4 times as many filters.
    is_training: `bool` for whether the model is in training mode.
    strides: `int` block stride. If greater than 1, this block will ultimately
        downsample the input.
    projection_shortcut: `function` to use for projection shortcuts (typically
        a 1x1 convolution to match the filter dimensions). If None, no
        projection is used and the input is passed as unchanged through the
        shortcut connection.
    row_blocks_dim: a mtf.Dimension, row dimension which is
        spatially partitioned along mesh axis
    col_blocks_dim: a mtf.Dimension, row dimension which is
        spatially partitioned along mesh axis

  Returns:
    The output `Tensor` of the block.
  """
    shortcut = inputs

    filter_h_dim = mtf.Dimension("filter_height", 3)
    filter_w_dim = mtf.Dimension("filter_width", 3)
    one_h_dim = mtf.Dimension("filter_height", 1)
    one_w_dim = mtf.Dimension("filter_width", 1)

    if projection_shortcut is not None:
        filters_dim = mtf.Dimension("filtersp", filters)
        kernel = mtf.get_variable(
            inputs.mesh, "kernel",
            mtf.Shape(
                [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters_dim]))
        shortcut = projection_shortcut(inputs, kernel)

    # First conv block
    filters1_dim = mtf.Dimension("filters1", filters)
    kernel1 = mtf.get_variable(
        inputs.mesh, "kernel1",
        mtf.Shape([one_h_dim, one_w_dim, inputs.shape.dims[-1], filters1_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    kernel1,
                                    strides=[1, 1, 1, 1],
                                    padding="SAME",
                                    h_blocks_dim=None,
                                    w_blocks_dim=col_blocks_dim)

    # TODO(nikip): Add Dropout?
    inputs = batch_norm_relu(inputs, is_training)

    # Second conv block
    filters2_dim = mtf.Dimension("filters2", filters)
    kernel2 = mtf.get_variable(
        inputs.mesh, "kernel2",
        mtf.Shape([filter_h_dim, filter_w_dim, filters1_dim, filters2_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    kernel2,
                                    strides=[1, 1, 1, 1],
                                    padding="SAME",
                                    h_blocks_dim=row_blocks_dim,
                                    w_blocks_dim=col_blocks_dim)

    inputs = batch_norm_relu(inputs, is_training)

    # Third wide conv filter block
    filters3_dim = mtf.Dimension("filters3", filters)
    filters3_kernel = mtf.get_variable(
        inputs.mesh, "wide_kernel",
        mtf.Shape([one_h_dim, one_w_dim, filters2_dim, filters3_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    filters3_kernel,
                                    strides,
                                    padding="SAME",
                                    h_blocks_dim=None,
                                    w_blocks_dim=col_blocks_dim)

    inputs = batch_norm_relu(inputs, is_training, relu=False)

    # TODO(nikip): Maybe add residual with a projection?
    return mtf.relu(inputs + mtf.rename_dimension(
        shortcut, shortcut.shape.dims[-1].name, inputs.shape.dims[-1].name))
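The final rename is needed because element-wise ops in Mesh TensorFlow align dimensions by name: the shortcut's channel dimension must carry the same name as the block output before the two can be added channel-for-channel. A minimal sketch of that alignment, with assumed names and sizes:

import tensorflow as tf
import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch = mtf.Dimension("batch", 2)
filtersp = mtf.Dimension("filtersp", 64)
filters3 = mtf.Dimension("filters3", 64)

shortcut = mtf.zeros(mesh, mtf.Shape([batch, filtersp]), dtype=tf.float32)
block_out = mtf.zeros(mesh, mtf.Shape([batch, filters3]), dtype=tf.float32)

# Give the shortcut's channel dimension the same name as the block output,
# then add and apply the final relu, as in the return statement above.
aligned = mtf.rename_dimension(shortcut, "filtersp", "filters3")
residual = mtf.relu(block_out + aligned)   # shape [batch, filters3]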
Example #6
def _my_concat(a, b):
    a = mtf.rename_dimension(a, "beam", "triple_beam")
    b = mtf.rename_dimension(b, "double_beam", "triple_beam")
    return mtf.concat([a, b], "triple_beam")
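Both renames are needed because mtf.concat concatenates along a dimension name that every input must carry. A minimal sketch, with assumed beam sizes: "beam" (size 4) and "double_beam" (size 8) are both renamed to "triple_beam" and joined into a size-12 dimension.

import tensorflow as tf
import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch = mtf.Dimension("batch", 2)

a = mtf.zeros(mesh, mtf.Shape([batch, mtf.Dimension("beam", 4)]),
              dtype=tf.float32)
b = mtf.zeros(mesh, mtf.Shape([batch, mtf.Dimension("double_beam", 8)]),
              dtype=tf.float32)

merged = _my_concat(a, b)
assert merged.shape.dims[1].name == "triple_beam"
assert merged.shape.dims[1].size == 12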