Example No. 1
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              pos_type='fixed-base',
              pos_axial_shape=(),
              pos_d_axial_embs=None,
              pos_start_from_zero_prob=1.0,
              pos_max_offset_to_add=0,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              loss_sparsity_type='mult',
              loss_sparsity=0,
              loss_d_lowrank=0,
              loss_sparsity_prob=None,
              attention_chunk_size=0,
              n_layers_forget=0,
              n_decoder_attention_layers=2,
              use_bfloat16=False,
              reversible_encoder=False,
              use_two_swaps_per_encoder_block=True,
              center_layernorm=True,
              half_before_layer=None,
              double_after_layer=None,
              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.
    pos_start_from_zero_prob: how often to start from position 0 during
      training (if 1.0, we always start from position 0; if less, we
      randomize).
    pos_max_offset_to_add: maximum offset to add to positions during training
      when randomizing; this offset plus input length must still be less than
      max_len for all training examples.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to be used in the loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_layers_forget: how often to have a forgetting block between layers
    n_decoder_attention_layers: how many attention layers in a decoder block
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    reversible_encoder: whether to be reversible through the encoder
    use_two_swaps_per_encoder_block: whether to allow an even number of swaps
      in the encoder
    center_layernorm: whether to use centering in LayerNorm (default) or to
      skip it, which is known as RMS normalization.
    half_before_layer: int, half d_model and d_ff before that layer
    double_after_layer: int, double d_model and d_ff after that layer
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # Set default dimensions for attention head key and value sizes.
  if (d_model / 2) % n_heads != 0:
    raise ValueError(f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')
  if d_attention_key is None:
    d_attention_key = d_model // n_heads
  if d_attention_value is None:
    d_attention_value = d_model // n_heads

  # Set values of d_model, d_ff and d_qkv for the first stage.
  d_model1, d_ff1 = d_model, d_ff
  d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value
  if half_before_layer:
    d_model1, d_ff1 = d_model / 2, d_ff / 2
    d_attention_key1 = d_attention_key / 2
    d_attention_value1 = d_attention_value / 2

  # Set values of d_model, d_ff and d_qkv for the final stage.
  d_model2, d_ff2 = d_model, d_ff
  d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value
  if double_after_layer:
    d_model2, d_ff2 = d_model * 2, d_ff * 2
    d_attention_key2 = d_attention_key * 2
    d_attention_value2 = d_attention_value * 2

  # Vector embeddings.
  in_encoder, out_encoder, output_vocab_size = (
      ct.EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model1,
          mode,
          dropout,
          [-2],  # dropout_shared_axes
          max_len,
          output_vocab_size=output_vocab_size,
          pos_type=pos_type,
          pos_axial_shape=pos_axial_shape,
          pos_d_axial_embs=pos_d_axial_embs,
          pos_start_from_zero_prob=pos_start_from_zero_prob,
          pos_max_offset_to_add=pos_max_offset_to_add,
          use_bfloat16=use_bfloat16)
  )

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model1, d_ff1, n_heads, encoder_attention_type,
          dropout=dropout,
          ff_activation=ff_activation,
          ff_dropout=ff_dropout,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          ff_sparsity=ff_sparsity,
          attention_chunk_size=attention_chunk_size,
          center_layernorm=center_layernorm,
          use_bfloat16=use_bfloat16,
          use_two_swaps_per_block=use_two_swaps_per_encoder_block,
          mode=mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = [     # vec_e mask_e tok_e tok_d tok_d
      tl.ReversibleSelect([0, 0]),     # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      _ReversibleSerialForget(encoder_blocks, d_model1, n_layers_forget)
  ]
  if not reversible_encoder:
    encoder += [
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.Dense(d_model1, use_bfloat16=use_bfloat16),
        tl.LayerNorm(),
    ]
  encoder = tl.Serial(encoder)
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = []

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    # Grow d_model, d_ff, and d_qkv if requested.
    d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1
    if half_before_layer and layer_idx >= half_before_layer:
      d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value
    if double_after_layer and layer_idx > double_after_layer:
      d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2
    decoder_block = DecoderBlock(
        d_m, d_f, d_k, d_v, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_dropout=ff_dropout,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        ff_sparsity=ff_sparsity,
        attention_chunk_size=attention_chunk_size,
        n_attention_layers=n_decoder_attention_layers,
        center_layernorm=center_layernorm,
        use_bfloat16=use_bfloat16,
        mode=mode)
    decoder_blocks.append(decoder_block)
    if half_before_layer and layer_idx == half_before_layer - 1:
      decoder_blocks.append(tl.ReversibleConcatenatePair())
    if double_after_layer and layer_idx == double_after_layer:
      decoder_blocks.append(tl.ReversibleConcatenatePair())

  dense_loss_layer = tl.SparseDenseWithOptions(
      output_vocab_size,
      d_input=d_model2,
      sparsity_type=loss_sparsity_type,
      sparsity=loss_sparsity,
      d_lowrank=loss_d_lowrank,
      prob_sparse=loss_sparsity_prob,
      use_bfloat16=use_bfloat16,
      mode=mode)

  # Layers to merge encoder and decoder, see below for details.
  if reversible_encoder:
    encdec_layers = [
        tl.ReversibleSelect([0, 1, 4, 2, 3]),  # vec_e vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding2(mode=mode),      # vec_ed vec_ed tok_e tok_d
    ]
  else:
    encdec_layers = [
        tl.ReversibleSelect([0, 3, 1, 2]),     # vec_e vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding(mode=mode),       # vec_ed tok_e tok_d
        tl.ReversibleSelect([0, 0]),           # vec_ed vec_ed tok_e tok_d
    ]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 0, 1, 1]),                # tok_e tok_e tok_e tok_d tok_d

      # Embed in and out tokens; done together as weights may be shared.
      tl.Parallel(in_encoder, [], [],            # vec_e tok_e tok_e vec_d tok_d
                  [tl.ShiftRight(mode=mode), out_encoder]),

      tl.Parallel([], [tl.PaddingMask(),
                       tl.Fn('Squeeze',
                             lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                         # vec_e mask_e tok_e vec_d tok_d

      # Encode.
      encoder,                                  # vec_e mask_e tok_e vec_d tok_d

      # Concat encoder and decoder, given encoder mask.
      encdec_layers,

      # Run decoder blocks.
      _ReversibleSerialForget(decoder_blocks, d_model2,
                              n_layers_forget),    # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),           # vec_ed tok_e tok_d
      tl.LayerNorm(),                              # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                        # vec_ed tok_e tok_d tok_d
      t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

      # Map to output vocab.
      dense_loss_layer,  # vec_d tok_d
  )
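
A minimal usage sketch for the Reformer2 constructor above (an illustration, not a recommended configuration). It assumes the rest of this module is importable (EncoderBlock, DecoderBlock, the ct and t2 helpers) along with numpy and trax.shapes; note that d_model / 2 must be divisible by n_heads, as checked at the top of the function.

import numpy as np
from trax import shapes

# Tiny, illustrative sizes; (d_model / 2) % n_heads == 0 as required above.
model = Reformer2(input_vocab_size=320, d_model=32, d_ff=64,
                  n_encoder_layers=1, n_decoder_layers=1, n_heads=2,
                  max_len=128, mode='train')

# The model consumes a (source_tokens, target_tokens) pair of integer batches.
src = np.ones((1, 16), dtype=np.int32)
tgt = np.ones((1, 16), dtype=np.int32)
model.init(shapes.signature((src, tgt)))
logits, _ = model((src, tgt))  # logits: (1, 16, 320) activations over the vocab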
Example No. 2
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               attention_type=tl.SelfAttention,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               ff_sparsity=0,
               loss_sparsity_type='mult',
               loss_sparsity=0,
               loss_d_lowrank=0,
               loss_sparsity_prob=None,
               attention_chunk_size=0,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    attention_type: class: attention class to use, such as SelfAttention.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to be used in the loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    mode: str: 'train', 'eval', or 'predict'

  Returns:
    the layer.
  """
  positional_encoding = ct.PositionalEncoder(
      mode, dropout, max_len, axial_pos_shape, d_axial_pos_embs)

  positional_embedder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_dropout=dropout,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        ff_sparsity=ff_sparsity,
        attention_chunk_size=attention_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  dense_loss_layer = tl.SparseDenseWithOptions(
      vocab_size,
      d_input=d_model,
      sparsity_type=loss_sparsity_type,
      sparsity=loss_sparsity,
      d_lowrank=loss_d_lowrank,
      prob_sparse=loss_sparsity_prob,
      mode=mode)

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial(decoder_blocks),
      tl.Concatenate(),
      # TODO(kitaev): Test whether dropout should go before or after the
      # LayerNorm, and whether dropout broadcasting is needed here.
      tl.LayerNorm(),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
      dense_loss_layer,
  )
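
A similar sketch for ReformerLM (illustrative sizes; assumes numpy and trax.shapes are importable alongside this module):

import numpy as np
from trax import shapes

lm = ReformerLM(vocab_size=256, d_model=64, d_ff=128, n_layers=2,
                n_heads=2, max_len=128, mode='train')
tokens = np.ones((1, 32), dtype=np.int32)
lm.init(shapes.signature(tokens))
logits = lm(tokens)  # (1, 32, 256): per-position activations over the vocab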
Example No. 3
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              axial_pos_shape='fixed-base',
              d_axial_pos_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              loss_sparsity_type='mult',
              loss_sparsity=0,
              loss_d_lowrank=0,
              loss_sparsity_prob=None,
              attention_chunk_size=0,
              n_layers_forget=0,
              n_decoder_attention_layers=2,
              use_bfloat16=False,
              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to be used in the loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_layers_forget: how often to have a forgetting block between layers
    n_decoder_attention_layers: how many attention layers in a decoder block
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # Set default dimensions for attention head key and value sizes.
  if d_attention_key is None:
    if d_model % n_heads != 0:
      raise ValueError(f'n_heads ({n_heads}) must divide d_model ({d_model})')
    d_attention_key = d_model // n_heads
  if d_attention_value is None:
    if d_model % n_heads != 0:
      raise ValueError(f'n_heads ({n_heads}) must divide d_model ({d_model})')
    d_attention_value = d_model // n_heads

  # Vector embeddings.
  in_encoder, out_encoder, output_vocab_size = (
      ct.EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model,
          mode,
          dropout,
          [-2],  # dropout_shared_axes
          max_len,
          output_vocab_size=output_vocab_size,
          axial_pos_shape=axial_pos_shape,
          d_axial_pos_embs=d_axial_pos_embs,
          use_bfloat16=use_bfloat16)
  )

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, encoder_attention_type,
          dropout=dropout,
          ff_activation=ff_activation,
          ff_dropout=ff_dropout,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          ff_sparsity=ff_sparsity,
          attention_chunk_size=attention_chunk_size,
          use_bfloat16=use_bfloat16,
          mode=mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([                # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                        # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.Dense(d_model, use_bfloat16=use_bfloat16),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = []

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_dropout=ff_dropout,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        ff_sparsity=ff_sparsity,
        attention_chunk_size=attention_chunk_size,
        n_attention_layers=n_decoder_attention_layers,
        use_bfloat16=use_bfloat16,
        mode=mode)
    decoder_blocks.append(decoder_block)

  dense_loss_layer = tl.SparseDenseWithOptions(
      output_vocab_size,
      d_input=d_model,
      sparsity_type=loss_sparsity_type,
      sparsity=loss_sparsity,
      d_lowrank=loss_d_lowrank,
      prob_sparse=loss_sparsity_prob,
      use_bfloat16=use_bfloat16,
      mode=mode)

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 0, 1, 1]),                # tok_e tok_e tok_e tok_d tok_d

      # Embed in and out tokens; done together as weights may be shared.
      tl.Parallel(in_encoder, [], [],            # vec_e tok_e tok_e vec_d tok_d
                  [tl.ShiftRight(mode=mode), out_encoder]),

      tl.Parallel([], [tl.PaddingMask(),
                       tl.Fn('Squeeze',
                             lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                         # vec_e mask_e tok_e vec_d tok_d

      # Encode.
      encoder,                                  # vec_e mask_e tok_e vec_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),                 #  vec_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given encoder mask.
      tl.Select([1, 0]),                       # vec_e vec_d mask_e tok_e tok_d
      t2.ConcatWithPadding(mode=mode),         # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                                    # vec_ed1 vec_ed2 tok_e tok_d
      _ReversibleSerialForget(decoder_blocks, d_model,
                              n_layers_forget),    # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),           # vec_ed tok_e tok_d
      tl.LayerNorm(),                              # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                        # vec_ed tok_e tok_d tok_d
      t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

      # Map to output vocab.
      dense_loss_layer,  # vec_d tok_d
  )
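
Both Reformer2 variants, like ReformerLM, also accept a list of attention classes for encoder_decoder_attention_type and cycle through it across the decoder blocks. A hedged sketch, assuming tl.LSHSelfAttention is available in this version of trax:

model = Reformer2(
    input_vocab_size=320,
    encoder_decoder_attention_type=[tl.SelfAttention, tl.LSHSelfAttention],
    n_decoder_layers=6,  # must be a multiple of the list length (here 2)
    mode='train')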
Example No. 4
    def test_loss_layer_timing(self):
        all_settings = [
            # The first run is sometimes slower, less reliable.
            {
                'output': 32000,
                'input': 2048,
                'prob': None,
                'type': None,
                'sparsity': 0,
                'lowrank': 0,
                'use_bias': False
            },
            {
                'output': 32000,
                'input': 2048,
                'prob': None,
                'type': None,
                'sparsity': 0,
                'lowrank': 0,
                'use_bias': False
            },
            {
                'output': 32000,
                'input': 2048,
                'prob': None,
                'type': 'einsum',
                'sparsity': 0,
                'lowrank': 0,
                'use_bias': False
            },
            {
                'output': 32000,
                'input': 2048,
                'prob': None,
                'type': 'mult',
                'sparsity': 2,
                'lowrank': 0,
                'use_bias': False
            },
            {
                'output': 32000,
                'input': 2048,
                'prob': None,
                'type': None,
                'sparsity': 0,
                'lowrank': 0,
                'use_bias': True
            },
            {
                'output': 32000,
                'input': 2048,
                'prob': None,
                'type': 'einsum',
                'sparsity': 0,
                'lowrank': 0,
                'use_bias': True
            },
            {
                'output': 32000,
                'input': 2048,
                'prob': None,
                'type': 'mult',
                'sparsity': 2,
                'lowrank': 0,
                'use_bias': True
            },
        ]

        messages = []
        for settings in all_settings:
            pred_model = tl.SparseDenseWithOptions(
                n_units=settings['output'],
                d_input=settings['input'],
                sparsity_type=settings['type'],
                sparsity=settings['sparsity'],
                d_lowrank=settings['lowrank'],
                prob_sparse=settings['prob'],
                use_bias=settings['use_bias'],
                mode='predict',
            )
            pred_model = tl.Accelerate(pred_model)

            shape1l = shapes.ShapeDtype((1, settings['input']))
            pred_model.init(input_signature=shape1l)
            inputs = np.ones((1, settings['input']))

            total_time = 0.0
            for counter in range(-50, 100):
                start_time = time.time()
                y = pred_model(inputs)
                self.assertEqual(y.shape, (1, settings['output']))
                elapsed_time = time.time() - start_time
                if counter >= 0:
                    total_time += elapsed_time

            message = (
                '\n\nParams: %d Settings: %s\nTime for 100 tokens: %.4f s\n\n\n'
                % (_size_of_model(pred_model), settings, total_time))
            messages.append(message)
            print(message)

        print('Final results (recap):')
        for message in messages:
            print(message)
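
The timing test above relies on a _size_of_model helper defined elsewhere in the test file. A plausible sketch of such a helper (an assumption, not necessarily the original) that sums the element counts of all weight arrays:

from trax import fastmath

def _size_of_model(model):
    # Map every weight array to its element count, then sum the counts.
    def _size(x):
        try:
            return x.size
        except AttributeError:
            return 0
    sizes = fastmath.nested_map(_size, model.weights)
    return sum(fastmath.tree_flatten(sizes))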
Example No. 5
def ConfigurableTransformer(input_vocab_size,
                            output_vocab_size=None,
                            d_model=512,
                            d_ff=2048,
                            n_encoder_layers=6,
                            n_decoder_layers=6,
                            n_heads=8,
                            max_len=2048,
                            dropout=0.1,
                            dropout_shared_axes=None,
                            mode='train',
                            ff_activation=tl.Relu,
                            ff_dropout=0.1,
                            ff_chunk_size=0,
                            ff_use_sru=0,
                            ff_sparsity=0,
                            ff_sparsity_type='1inN',
                            loss_sparsity_type='mult',
                            loss_sparsity=0,
                            loss_d_lowrank=0,
                            loss_sparsity_prob=None,
                            attention_chunk_size=0,
                            encoder_attention_type=tl.Attention,
                            encoder_decoder_attention_type=tl.CausalAttention,
                            axial_pos_shape=None,
                            d_axial_pos_embs=None):
    """Returns a full Transformer model.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(input_vocab_size)`, and `0`
          values mark padding positions.

        - target: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(output_vocab_size)`, and `0`
          values mark padding positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `vocab_size`).

  An example use would be to translate (tokenized) sentences from English to
  German.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
      should be an integer in `range(vocab_size)`. These integers typically
      represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
      if None, then input and target integers (token IDs) are assumed to come
      from the same vocabulary.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
      and decoder block.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value when
      applying dropout within an encoder/decoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder
      block will include dropout; else, it will pass all values through
      unaltered.
    ff_activation: Type of activation function at the end of each
      encoder/decoder block; must be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers
      in addition to the feed-forward block (second int specifies sru size)
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    ff_sparsity_type: string, if ff_sparsity >0,
      use SparseFF if ff_sparsity_type=`'1inN'` and
      use BlockSparseFF if ff_sparsity_type=`'Block'`
    loss_sparsity_type: str, type of sparsity to be used in the loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    encoder_attention_type: The attention layer to use for the encoder part.
    encoder_decoder_attention_type: The attention layer to use for the
      encoder-decoder attention.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.

  Returns:
    A Transformer model as a layer that maps from a source-target tokenized
    text pair to activations over a vocab set.
  """
    in_encoder, out_encoder, output_vocab_size = (
        EmbeddingAndPositionalEncodings(input_vocab_size,
                                        d_model,
                                        mode,
                                        dropout,
                                        dropout_shared_axes,
                                        max_len,
                                        output_vocab_size=output_vocab_size,
                                        axial_pos_shape=axial_pos_shape,
                                        d_axial_pos_embs=d_axial_pos_embs))

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                     mode, ff_activation, ff_dropout, ff_chunk_size,
                     ff_use_sru, ff_sparsity, ff_sparsity_type,
                     attention_chunk_size, encoder_attention_type)
        for i in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    # pylint: disable=g-complex-comprehension
    encoder_decoder_blocks = [
        EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                            dropout_shared_axes, mode, ff_activation,
                            ff_dropout, ff_chunk_size, ff_use_sru, ff_sparsity,
                            ff_sparsity_type, attention_chunk_size,
                            encoder_decoder_attention_type)
        for i in range(n_decoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d

        # Encode.
        tl.Branch([], tl.PaddingMask()),  # tok_e masks ..... .....
        encoder,  # vec_e ..... ..... .....

        # Decode.
        tl.Select([2, 1, 0]),  # tok_d masks vec_e .....
        tl.ShiftRight(mode=mode),  # tok_d ..... ..... .....
        out_encoder,  # vec_d ..... ..... .....
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        encoder_decoder_blocks,  # vec_d masks ..... .....
        tl.LayerNorm(),  # vec_d ..... ..... .....

        # Map to output vocab.
        tl.Select([0], n_in=3),  # vec_d tok_d
        tl.SparseDenseWithOptions(  # vec_d .....
            output_vocab_size,
            d_input=d_model,
            sparsity_type=loss_sparsity_type,
            sparsity=loss_sparsity,
            d_lowrank=loss_d_lowrank,
            prob_sparse=loss_sparsity_prob,
            mode=mode),
    )
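
A minimal usage sketch for ConfigurableTransformer (illustrative sizes; assumes numpy and trax.shapes are importable alongside this module and its block helpers):

import numpy as np
from trax import shapes

model = ConfigurableTransformer(
    input_vocab_size=128, output_vocab_size=256, d_model=32, d_ff=64,
    n_encoder_layers=1, n_decoder_layers=1, n_heads=2, max_len=64,
    mode='train')
src = np.ones((2, 10), dtype=np.int32)
tgt = np.ones((2, 10), dtype=np.int32)
model.init(shapes.signature((src, tgt)))
logits, _ = model((src, tgt))  # (2, 10, 256) log-probabilities over target vocab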
Example No. 6
def ConfigurableTransformerLM(vocab_size,
                              d_model=512,
                              d_ff=2048,
                              n_layers=6,
                              n_heads=8,
                              max_len=2048,
                              dropout=0.1,
                              dropout_shared_axes=None,
                              mode='train',
                              ff_activation=tl.Relu,
                              ff_dropout=0.1,
                              ff_chunk_size=0,
                              ff_use_sru=0,
                              ff_sparsity=0,
                              ff_sparsity_type='1inN',
                              loss_sparsity_type='mult',
                              loss_sparsity=0,
                              loss_d_lowrank=0,
                              loss_sparsity_prob=None,
                              attention_chunk_size=0,
                              attention_type=tl.CausalAttention,
                              axial_pos_shape=None,
                              d_axial_pos_embs=None):
    """Returns a Transformer language model.

  This model performs autoregressive language modeling:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `vocab_size`).

  This model uses only the decoder part of the overall Transformer.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor should
      be an integer in `range(vocab_size)`. These integers typically represent
      token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
      block.
    n_layers: Number of encoder blocks. Each block includes attention, dropout,
      residual, feed-forward (`Dense`), and activation layers.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value when
      applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    mode: If `'predict'`, use fast inference. If `'train'`, each encoder block
      will include dropout; else, it will pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder block;
      must be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers
      in addition to the feed-forward block (second int specifies sru size)
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    ff_sparsity_type: string, if ff_sparsity >0,
      use SparseFF if ff_sparsity_type=`'1inN'` and
      use BlockSparseFF if ff_sparsity_type=`'Block'`
    loss_sparsity_type: string, type of sparsity to be used in the loss layer.
      See SparseDenseWithOptions for options. None if no sparsity should be
      used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    attention_type: The attention layer to use for the decoder part.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
    positional_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        PositionalEncoder(mode, dropout, max_len, axial_pos_shape,
                          d_axial_pos_embs)
    ]

    # pylint: disable=g-complex-comprehension
    decoder_blocks = [
        DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                     mode, ff_activation, ff_dropout, ff_chunk_size,
                     ff_use_sru, ff_sparsity, ff_sparsity_type,
                     attention_chunk_size, attention_type)
        for i in range(n_layers)
    ]
    # pylint: enable=g-complex-comprehension

    # Assemble and return the model.
    return tl.Serial(  # tokens (or chunked tuple of tokens)
        tl.ShiftRight(mode=mode),  # toks
        positional_encoder,  # vecs
        decoder_blocks,  # vecs
        tl.LayerNorm(),  # vecs
        tl.SparseDenseWithOptions(  # vecs
            vocab_size,
            d_input=d_model,
            sparsity_type=loss_sparsity_type,
            sparsity=loss_sparsity,
            d_lowrank=loss_d_lowrank,
            prob_sparse=loss_sparsity_prob,
            mode=mode),
    )
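
And a matching sketch for ConfigurableTransformerLM (illustrative sizes; assumes numpy and trax.shapes are importable alongside this module):

import numpy as np
from trax import shapes

lm = ConfigurableTransformerLM(vocab_size=100, d_model=32, d_ff=64, n_layers=2,
                               n_heads=2, max_len=64, mode='train')
tokens = np.ones((2, 10), dtype=np.int32)
lm.init(shapes.signature(tokens))
logits = lm(tokens)  # (2, 10, 100): per-position log-probabilities over the vocab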