Code example #1
def local_attention1d_masked_decoder(x, kv_dim, heads_dim, feedforward_dim,
                                     hparams):
    """Image Transformer decoder with local1D masked layers."""
    print(x)
    _, length_dim, model_dim = x.shape.dims
    for layer in range(hparams.num_decoder_layers):
        layer_name = "decoder_layer_%d" % layer
        with tf.variable_scope(layer_name):
            # Self attention layer
            length_per_split = mtf.tensor_dim_to_size_per_split(
                hparams.layout, hparams.mesh_shape, length_dim)
            x += layer_prepostprocess_dropout(
                mtf.layers.masked_local_attention_1d(
                    mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
                    kv_dim,
                    heads_dim,
                    window_size=hparams.block_length,
                    length_per_split=length_per_split,
                    name="self_att"), hparams)
            # ffn layer
            x += layer_prepostprocess_dropout(
                mtf.layers.dense_relu_dense(
                    mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
                    feedforward_dim,
                    hparams.dropout,
                    dropout_broadcast_dims=[length_dim]), hparams)

    output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
    return output
Code example #2
def local_attention1d_masked_decoder(x, kv_dim, heads_dim,
                                     feedforward_dim, hparams):
  """Image Transformer decoder with local1D masked layers."""
  print(x)
  _, length_dim, model_dim = x.shape.dims
  for layer in range(hparams.num_decoder_layers):
    layer_name = "decoder_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Self attention layer
      length_per_split = mtf.tensor_dim_to_size_per_split(
          hparams.layout, hparams.mesh_shape, length_dim)
      x += layer_prepostprocess_dropout(
          mtf.layers.masked_local_attention_1d(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
              kv_dim,
              heads_dim,
              window_size=hparams.block_length,
              length_per_split=length_per_split,
              name="self_att"), hparams)
      # ffn layer
      x += layer_prepostprocess_dropout(
          mtf.layers.dense_relu_dense(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
              feedforward_dim,
              hparams.dropout,
              dropout_broadcast_dims=[length_dim]), hparams)

  output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
  return output
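Examples #1 and #2 are the same decoder body (only the indentation differs). Structurally, each layer applies two residual, pre-norm sublayers (masked local self-attention, then a ReLU feed-forward), and the stack ends with one final layer norm. The following plain-NumPy sketch reproduces just that skeleton with toy stand-ins for the sublayers; none of the names below come from mesh_tensorflow, they are illustrative only.

import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize over the model (last) axis.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def pre_norm_residual(x, sublayer):
    # The pattern used above: x += sublayer(layer_norm(x)).
    return x + sublayer(layer_norm(x))

rng = np.random.default_rng(0)
length, d_model, d_ff = 8, 16, 32
x = rng.standard_normal((length, d_model))
w_in = 0.02 * rng.standard_normal((d_model, d_ff))
w_out = 0.02 * rng.standard_normal((d_ff, d_model))

for _ in range(2):  # num_decoder_layers = 2
    x = pre_norm_residual(x, lambda h: h)  # attention stand-in (identity)
    x = pre_norm_residual(x, lambda h: np.maximum(h @ w_in, 0.0) @ w_out)  # dense_relu_dense stand-in
x = layer_norm(x)  # final_layer_norm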
Code example #3
File: utils.py Project: sara-nl/PRACE-advanced-dl
def serialize_num_microbatches(batch_dim,
                               length_dim,
                               mesh_shape,
                               layout_rules,
                               tokens_per_microbatch_per_replica=None):
  """Number of microbatches per batch for serialized training.

  We want to split each training step into multiple sequential steps
  to limit memory usage.  Gradients are accumulated locally and reduced once.

  This function determines the number of microbatches per batch.
  If tokens_per_microbatch_per_replica=None, then the batch is not split.

  Args:
    batch_dim: a mtf.Dimension
    length_dim: a mtf.Dimension
    mesh_shape: an input to mtf.convert_to_shape()
    layout_rules: an input to mtf.convert_to_layout_rules()
    tokens_per_microbatch_per_replica: an optional integer, e.g. 2048
  Returns:
    an integer
  """
  if not tokens_per_microbatch_per_replica:
    return 1
  batch_per_replica = mtf.tensor_dim_to_size_per_split(
      layout_rules, mesh_shape, batch_dim)
  # number of sequences per microbatch
  microbatch_size = max(1, tokens_per_microbatch_per_replica // length_dim.size)
  # decrease microbatch_size until it is a divisor of batch_per_replica
  # This is guaranteed to stop at microbatch_size=1 if not earlier.
  while batch_per_replica % microbatch_size:
    microbatch_size -= 1
  num_microbatches = batch_per_replica // microbatch_size
  tf.logging.info(
      "serialize_num_microbatches: "
      "tokens_per_microbatch_per_replica=%d "
      "batch_dim=%s "
      "length_dim=%s "
      "batch_per_replica=%d "
      "num_microbatches=%d",
      tokens_per_microbatch_per_replica,
      batch_dim,
      length_dim,
      batch_per_replica,
      num_microbatches)
  return num_microbatches
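As a concrete check of the arithmetic above, here is the same logic restated with plain integers (a standalone sketch; the real function receives mtf.Dimension objects plus layout/mesh specs, whereas here the per-replica batch size and sequence length are assumed to be already known).

def num_microbatches(batch_per_replica, length, tokens_per_microbatch_per_replica=None):
    # Plain-integer restatement of serialize_num_microbatches.
    if not tokens_per_microbatch_per_replica:
        return 1
    # Number of sequences that fit within the per-replica token budget.
    microbatch_size = max(1, tokens_per_microbatch_per_replica // length)
    # Shrink until it divides the per-replica batch evenly (stops at 1 at worst).
    while batch_per_replica % microbatch_size:
        microbatch_size -= 1
    return batch_per_replica // microbatch_size

# 8 sequences of length 512 per replica with a 2048-token microbatch budget:
# 2048 // 512 = 4 sequences per microbatch, so 8 / 4 = 2 microbatches.
assert num_microbatches(8, 512, 2048) == 2
# With no token budget, the batch is not split.
assert num_microbatches(8, 512) == 1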
Code example #4
  def _layer_stack(self,
                   x,
                   layers,
                   encoder_output=None,
                   self_attention_mask=None,
                   encdec_attention_mask=None,
                   losses=None,
                   step_num=None,
                   encdec_tensors=None,
                   states=None):
    """Encoder or decoder stack.

    Args:
      x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
      layers: a list of strings
      encoder_output: an optional mtf.Tensor with shape
        [<batch_dims>, encoder_length_dim, model_dim]
      self_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, memory_length_dim] containing values 0 or -inf.
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
      losses: an optional list to which extra losses are appended
      step_num: an optional mtf integer Scalar (used in incremental mode)
      encdec_tensors: an optional list of num_layers tuples, each of the form
        (q_var, o_var, k, v) (used in incremental mode)
      states: an optional list of Tensors (used in incremental mode)
    Returns:
      a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
    Raises:
      ValueError: if hparams make no sense
    """
    hparams = self._hparams
    is_incremental = (step_num is not None)
    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    def layer_prepostprocess_dropout(x):
      if is_incremental:
        return x
      return mtf.dropout(
          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
    num_layers = len(layers)
    num_layer_norms = num_layers + 1
    layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
    layer_norm_combined_var = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([layer_norms_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)
    def normalize(x):
      scale = layer_norm_vars.pop(0)
      variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
      return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

    if is_incremental:
      states = list(states)
      new_states = []
    tf.logging.info("states = %s" % (states,))

    for lnum, layer_type in enumerate(layers):
      with tf.variable_scope("%s_%d" % (layer_type, lnum)):
        if layer_type == "att":
          # Self attention layer
          if is_incremental:
            y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
                normalize(x),
                prev_k=states.pop(0),
                prev_v=states.pop(0),
                step_num=step_num,
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="att")
            new_states.append(new_k)
            new_states.append(new_v)
            x += y
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_attention(
                    normalize(x), None,
                    self_attention_mask, self.kv_dim, self.heads_dim,
                    is_training,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="att"))
        elif layer_type == "enc_att":
          # Encoder-Decoder attention layer
          if is_incremental:
            q_var, o_var, k, v = encdec_tensors[lnum]
            x += mtf.layers.multihead_encdec_attention_incremental(
                normalize(x),
                q_var, o_var, k, v,
                encdec_attention_mask,
                name="enc_att")
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_attention(
                    normalize(x), encoder_output,
                    encdec_attention_mask, self.kv_dim, self.heads_dim,
                    is_training,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="enc_att"))
        elif layer_type == "local_att":
          if is_incremental:
            y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
                normalize(x),
                prev_k=states.pop(0),
                prev_v=states.pop(0),
                step_num=step_num,
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="local_att")
            new_states.append(new_k)
            new_states.append(new_v)
            x += y
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.masked_local_attention_1d(
                    normalize(x),
                    self.kv_dim, self.heads_dim, is_training,
                    window_size=hparams.local_attention_window_size,
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    length_per_split=mtf.tensor_dim_to_size_per_split(
                        hparams.layout, hparams.mesh_shape,
                        self.max_length_dim),
                    name="local_att"))
        elif layer_type == "compressed_att":
          if is_incremental:
            raise ValueError("compressed_att incremental not implemented")
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_self_attention_memory_compressed(
                    normalize(x),
                    mask_right=True,
                    compression_factor=hparams.compression_factor,
                    kv_channels=self.kv_dim,
                    heads=self.heads_dim,
                    is_training=is_training,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="compressed_att"))
        else:
          if is_incremental:
            # insert length dimension.
            x_shape = x.shape
            shape_with_length = mtf.Shape(
                x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
                + x_shape.dims[-1:])
            x = mtf.reshape(x, shape_with_length)
          # ffn layer
          x += layer_prepostprocess_dropout(
              self._feedforward_layer(normalize(x), layer_type, losses=losses))
          if is_incremental:
            # remove length dimension
            x = mtf.reshape(x, x_shape)

    x = layer_prepostprocess_dropout(normalize(x))
    assert not layer_norm_vars
    if is_incremental:
      return x, new_states
    else:
      return x
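Two details of the stack above are easy to miss: normalize() is an RMS-style norm (no mean subtraction), and all layer-norm scales live in a single [layer_norms, model_dim] variable that is unstacked into a list and consumed one entry per normalize() call, with the trailing assert checking that exactly num_layers + 1 norms were applied. A plain-NumPy sketch of that bookkeeping follows; it is illustrative only, not mesh_tensorflow API.

import numpy as np

model_dim = 16
layers = ["att", "ffn", "att", "ffn"]
# One combined scale variable: one row per layer norm (len(layers) + 1 for the final norm).
layer_norm_scales = list(np.ones((len(layers) + 1, model_dim)))

def normalize(x, eps=1e-6):
    # RMS-style norm: divide by sqrt(mean(x**2)), then apply the next scale row.
    scale = layer_norm_scales.pop(0)
    variance = np.mean(np.square(x), axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * scale

x = np.random.default_rng(0).standard_normal((4, model_dim))
for _ in layers:
    x = x + 0.1 * normalize(x)  # stand-in for each residual sublayer
x = normalize(x)                # final normalization
assert not layer_norm_scales    # every scale row was consumed exactly once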
Code example #5
  def _layer_stack(self,
                   x,
                   layers,
                   encoder_output=None,
                   self_attention_mask=None,
                   encdec_attention_mask=None,
                   losses=None,
                   step_num=None,
                   encdec_tensors=None,
                   states=None):
    """Encoder or decoder stack.

    Args:
      x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
      layers: a list of strings
      encoder_output: an optional mtf.Tensor with shape
        [<batch_dims>, encoder_length_dim, model_dim]
      self_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, memory_length_dim] containing values 0 or -inf.
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
      losses: an optional list to which extra losses are appended
      step_num: an optional mtf integer Scalar (used in incremental mode)
      encdec_tensors: an optional list of num_layers tuples, each of the form
        (q_var, o_var, k, v) (used in incremental mode)
      states: an optional list of Tensors (used in incremental mode)
    Returns:
      a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
    Raises:
      ValueError: if hparams make no sense
    """
    hparams = self._hparams
    is_incremental = (step_num is not None)
    def layer_prepostprocess_dropout(x):
      if is_incremental:
        return x
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
    num_layers = len(layers)
    num_layer_norms = num_layers + 1
    layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
    layer_norm_combined_var = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([layer_norms_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)
    def normalize(x):
      scale = layer_norm_vars.pop(0)
      variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
      return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

    if is_incremental:
      states = list(states)
      new_states = []
    tf.logging.info("states = %s" % (states,))

    for lnum, layer_type in enumerate(layers):
      with tf.variable_scope("%s_%d" % (layer_type, lnum)):
        if layer_type == "att":
          # Self attention layer
          if is_incremental:
            y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
                normalize(x),
                prev_k=states.pop(0),
                prev_v=states.pop(0),
                step_num=step_num,
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="att")
            new_states.append(new_k)
            new_states.append(new_v)
            x += y
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_attention(
                    normalize(x), None,
                    self_attention_mask, self.kv_dim, self.heads_dim,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="att"))
        elif layer_type == "enc_att":
          # Encoder-Decoder attention layer
          if is_incremental:
            q_var, o_var, k, v = encdec_tensors[lnum]
            x += mtf.layers.multihead_encdec_attention_incremental(
                normalize(x),
                q_var, o_var, k, v,
                encdec_attention_mask,
                name="enc_att")
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_attention(
                    normalize(x), encoder_output,
                    encdec_attention_mask, self.kv_dim, self.heads_dim,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="enc_att"))
        elif layer_type == "local_att":
          if is_incremental:
            y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
                normalize(x),
                prev_k=states.pop(0),
                prev_v=states.pop(0),
                step_num=step_num,
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="local_att")
            new_states.append(new_k)
            new_states.append(new_v)
            x += y
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.masked_local_attention_1d(
                    normalize(x),
                    self.kv_dim, self.heads_dim,
                    window_size=hparams.local_attention_window_size,
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    length_per_split=mtf.tensor_dim_to_size_per_split(
                        hparams.layout, hparams.mesh_shape,
                        self.max_length_dim),
                    name="local_att"))
        elif layer_type == "compressed_att":
          if is_incremental:
            raise ValueError("compressed_att incremental not implemented")
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_self_attention_memory_compressed(
                    normalize(x),
                    mask_right=True,
                    compression_factor=hparams.compression_factor,
                    kv_channels=self.kv_dim,
                    heads=self.heads_dim,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="compressed_att"))
        else:
          if is_incremental:
            # insert length dimension.
            x_shape = x.shape
            shape_with_length = mtf.Shape(
                x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
                + x_shape.dims[-1:])
            x = mtf.reshape(x, shape_with_length)
          # ffn layer
          x += layer_prepostprocess_dropout(
              self._feedforward_layer(normalize(x), layer_type, losses=losses))
          if is_incremental:
            # remove length dimension
            x = mtf.reshape(x, x_shape)

    x = layer_prepostprocess_dropout(normalize(x))
    assert not layer_norm_vars
    if is_incremental:
      return x, new_states
    else:
      return x
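In incremental (decoding) mode, the stack threads a flat states list through the layers: each self-attention or local-attention layer pops its cached (prev_k, prev_v) pair off the front and appends the updated pair to new_states, which the caller feeds back in at the next decode step. Below is a minimal Python sketch of just that bookkeeping, with a dummy cache update standing in for the attention math; all names are illustrative.

def decode_step(x, layers, states):
    # Pop one (prev_k, prev_v) pair per attention layer; return the updated caches.
    states = list(states)
    new_states = []
    for layer_type in layers:
        if layer_type in ("att", "local_att"):
            prev_k, prev_v = states.pop(0), states.pop(0)
            new_k, new_v = prev_k + [x], prev_v + [x]  # dummy cache update
            new_states.extend([new_k, new_v])
            x = x + 1  # stand-in for the attention sublayer's output
        else:
            x = x + 1  # stand-in for the feed-forward sublayer
    return x, new_states

layers = ["att", "ffn", "local_att", "ffn"]
states = [[] for _ in range(4)]  # one empty (k, v) pair per attention layer
x = 0
for _ in range(3):               # three decode steps
    x, states = decode_step(x, layers, states)
assert len(states) == 4 and len(states[0]) == 3  # each cache grew by one entry per step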