Example #1
def test_state(self):
    model = tl.Parallel(tl.Dense(3), tl.Dense(5))
    self.assertIsInstance(model.state, tuple)
    self.assertLen(model.state, 2)
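
The assertions above rely on a tl.Parallel combinator keeping one state entry per
sublayer, so two Dense sublayers give a 2-tuple. A minimal standalone sketch of the
same check (assuming only that trax is installed):

import trax.layers as tl

model = tl.Parallel(tl.Dense(3), tl.Dense(5))
assert isinstance(model.state, tuple)  # combinator state is a tuple
assert len(model.state) == 2           # one entry per sublayer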
Example #2
def BERT(
    d_model=768,
    vocab_size=30522,
    max_len=512,
    type_vocab_size=2,
    n_heads=12,
    d_ff=3072,
    n_layers=12,
    head=None,
    init_checkpoint=None,
    mode='eval',
):
    """BERT (default hparams are for bert-base-uncased)."""
    layer_norm_eps = 1e-12
    d_head = d_model // n_heads

    word_embeddings = tl.Embedding(d_model, vocab_size)
    type_embeddings = tl.Embedding(d_model, type_vocab_size)
    position_embeddings = tl.PositionalEncoding(max_len, mode=mode)
    embeddings = [
        tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
        tl.Parallel(word_embeddings, type_embeddings, [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)
        ]),
        tl.Add(),
        position_embeddings,
        tl.LayerNorm(epsilon=layer_norm_eps),
    ]

    encoder = []
    for _ in range(n_layers):
        attn = tl.SelfAttention(n_heads=n_heads,
                                d_qk=d_head,
                                d_v=d_head,
                                bias=True,
                                masked=True,
                                mode=mode)
        feed_forward = [tl.Dense(d_ff), tl.Gelu(), tl.Dense(d_model)]
        encoder += [
            tl.Select([0, 1, 1]),  # Save a copy of the mask
            tl.Residual(attn, AddBias()),  # pylint: disable=no-value-for-parameter
            tl.LayerNorm(epsilon=layer_norm_eps),
            tl.Residual(*feed_forward),
            tl.LayerNorm(epsilon=layer_norm_eps),
        ]

    encoder += [tl.Select([0], n_in=2)]  # Drop the mask

    pooler = [
        tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
        tl.Dense(d_model),
        tl.Tanh(),
    ]

    init_checkpoint = init_checkpoint if mode == 'train' else None
    bert = PretrainedBERT(embeddings + encoder + pooler,
                          init_checkpoint=init_checkpoint)

    if head is not None:
        bert = tl.Serial(bert, head())

    return bert
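
A hedged usage sketch for the function above: head is a callable returning layers
that are appended after the pooler, which outputs the pooled [CLS] vector plus the
full token sequence. The two-class head below is purely illustrative and not part
of the original example.

import trax.layers as tl

def TwoClassHead():
    # Hypothetical head: keep only the pooled [CLS] vector (first of the two
    # pooler outputs), then project to two classes.
    return tl.Serial(tl.Select([0], n_in=2), tl.Dense(2), tl.LogSoftmax())

model = BERT(head=TwoClassHead, mode='eval')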
Example #3
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              pos_type='fixed-base',
              pos_axial_shape=(),
              pos_d_axial_embs=None,
              pos_start_from_zero_prob=1.0,
              pos_max_offset_to_add=0,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              loss_sparsity_type='mult',
              loss_sparsity=0,
              loss_d_lowrank=0,
              loss_sparsity_prob=None,
              attention_chunk_size=0,
              n_layers_forget=0,
              forget_dense=True,
              n_decoder_attention_layers=2,
              use_bfloat16=False,
              reversible_encoder=False,
              use_two_swaps_per_encoder_block=True,
              center_layernorm=True,
              half_before_layer=None,
              double_after_layer=None,
              mode='train'):
    """Reversible transformer encoder-decoder model.

  If input_vocab_size is not None, this model expects an input pair: source,
  target. Otherwise, it expects a triple: embedded_source, mask, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.
    pos_start_from_zero_prob: how often to start from position 0 during
      training (if 1.0, we always start from position 0; if less, we
      randomize).
    pos_max_offset_to_add: maximum offset to add to positions during training
      when randomizing; this offset plus input length must still be less than
      max_len for all training examples.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to use in the loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability of using the sparse version of
      the loss. If None, the sparse version is used exclusively.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_layers_forget: how often to have a forgetting block between layers
    forget_dense: whether to use Dense or no-op (Serial) as a forget layer.
    n_decoder_attention_layers: how many attention layers in a decoder block
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    reversible_encoder: whether to be reversible through the encoder
    use_two_swaps_per_encoder_block: whether to allow even number of swaps in
      the encoder
    center_layernorm: whether to use centering in LayerNorm (the default) or
      to skip it, which is known as RMS normalization.
    half_before_layer: int, half d_model and d_ff before that layer
    double_after_layer: int, double d_model and d_ff after that layer
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a (source, target) pair to
    activations over a vocab set.
  """
    # Set default dimensions for attention head key and value sizes.
    if (d_model / 2) % n_heads != 0:
        raise ValueError(
            f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')
    if d_attention_key is None:
        d_attention_key = d_model // n_heads
    if d_attention_value is None:
        d_attention_value = d_model // n_heads

    # Set values of d_model, d_ff and d_qkv for the first stage.
    d_model1, d_ff1 = d_model, d_ff
    d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value
    if half_before_layer:
        d_model1, d_ff1 = d_model // 2, d_ff // 2
        d_attention_key1 = d_attention_key // 2
        d_attention_value1 = d_attention_value // 2

    # Set values of d_model, d_ff and d_qkv for the final stage.
    d_model2, d_ff2 = d_model, d_ff
    d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value
    if double_after_layer:
        d_model2, d_ff2 = d_model * 2, d_ff * 2
        d_attention_key2 = d_attention_key * 2
        d_attention_value2 = d_attention_value * 2

    # Vector embeddings.
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(
            input_vocab_size,
            d_model1,
            mode,
            dropout,
            [-2],  # dropout_shared_axes
            max_len,
            output_vocab_size=output_vocab_size,
            pos_type=pos_type,
            pos_axial_shape=pos_axial_shape,
            pos_d_axial_embs=pos_d_axial_embs,
            pos_start_from_zero_prob=pos_start_from_zero_prob,
            pos_max_offset_to_add=pos_max_offset_to_add,
            use_bfloat16=use_bfloat16))

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model1,
                     d_ff1,
                     n_heads,
                     encoder_attention_type,
                     dropout=dropout,
                     ff_activation=ff_activation,
                     ff_dropout=ff_dropout,
                     ff_use_sru=ff_use_sru,
                     ff_chunk_size=ff_chunk_size,
                     ff_sparsity=ff_sparsity,
                     attention_chunk_size=attention_chunk_size,
                     center_layernorm=center_layernorm,
                     use_bfloat16=use_bfloat16,
                     use_two_swaps_per_block=use_two_swaps_per_encoder_block,
                     mode=mode) for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = [  # vec_e mask_e tok_e tok_d tok_d
        tl.ReversibleSelect([0, 0]),  # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
        _ReversibleSerialForget(encoder_blocks, d_model1, n_layers_forget,
                                forget_dense)
    ]
    if not reversible_encoder:
        encoder += [
            _XYAvg(),
            tl.Dense(d_model1, use_bfloat16=use_bfloat16),
            tl.LayerNorm(),
        ]
    encoder = tl.Serial(encoder)
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    decoder_blocks = []

    if isinstance(encoder_decoder_attention_type, (tuple, list)):
        assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    else:
        encoder_decoder_attention_type = [encoder_decoder_attention_type]
    for layer_idx in range(n_decoder_layers):
        layer_attention_type = encoder_decoder_attention_type[
            layer_idx % len(encoder_decoder_attention_type)]
        # Grow d_model, d_ff, and d_qkv if requested.
        d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1
        if half_before_layer and layer_idx >= half_before_layer:
            d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value
        if double_after_layer and layer_idx > double_after_layer:
            d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2
        decoder_block = DecoderBlock(
            d_m,
            d_f,
            d_k,
            d_v,
            n_heads,
            attention_type=layer_attention_type,
            dropout=dropout,
            ff_activation=ff_activation,
            ff_dropout=ff_dropout,
            ff_use_sru=ff_use_sru,
            ff_chunk_size=ff_chunk_size,
            ff_sparsity=ff_sparsity,
            attention_chunk_size=attention_chunk_size,
            n_attention_layers=n_decoder_attention_layers,
            center_layernorm=center_layernorm,
            use_bfloat16=use_bfloat16,
            mode=mode)
        decoder_blocks.append(decoder_block)
        if half_before_layer and layer_idx == half_before_layer - 1:
            decoder_blocks.append(tl.ReversibleConcatenatePair())
        if double_after_layer and layer_idx == double_after_layer:
            decoder_blocks.append(tl.ReversibleConcatenatePair())

    dense_loss_layer = tl.SparseDenseWithOptions(
        output_vocab_size,
        d_input=d_model2,
        sparsity_type=loss_sparsity_type,
        sparsity=loss_sparsity,
        d_lowrank=loss_d_lowrank,
        prob_sparse=loss_sparsity_prob,
        use_bfloat16=use_bfloat16,
        mode=mode)

    # Layers to merge encoder and decoder, see below for details.
    if reversible_encoder:
        encdec_layers = [
            tl.ReversibleSelect([0, 1, 4, 2, 3]),  # vec_e vec_d mask_e tok_e tok_d
            t2.ConcatWithPadding2(mode=mode),      # vec_ed vec_ed tok_e tok_d
        ]
    else:
        encdec_layers = [
            tl.ReversibleSelect([0, 3, 1, 2]),  # vec_e vec_d mask_e tok_e tok_d
            t2.ConcatWithPadding(mode=mode),    # vec_ed tok_e tok_d
            tl.ReversibleSelect([0, 0]),        # vec_ed vec_ed tok_e tok_d
        ]

    if input_vocab_size is not None:
        # Input in this case is tok_e, tok_d.
        mask_layers = [
            tl.PaddingMask(),
            _RemoveAxes12(),
        ]
        inp_layers = tl.Serial([
            tl.Select([0, 0, 0, 1]),  # tok_e tok_e tok_e tok_d
            tl.Parallel(in_encoder, mask_layers)  # vec_e mask_e tok_e tok_d
        ])
        inp_layers = tl.AssertFunction('bt,bu->btf,bt,bt,bu', inp_layers)
    else:
        # Input in this case is vec_e, mask_e, tok_d. Where all downstream
        # operations expect tok_e, we give it instead mask_e, expecting that
        # downstream ops only are looking for padding/not padding.
        inp_layers = tl.Serial([
            tl.Select([0, 1, 1, 2]),  # vec_e mask_e tok_e tok_d
            tl.Parallel(in_encoder, [],
                        _AsTokenIDs())  # vec_e mask_e tok_e tok_d
        ])
        inp_layers = tl.AssertFunction('btg,bt,bu->btf,bt,bt,bu', inp_layers)

    # Assemble and return the model.
    return tl.Serial(
        inp_layers,  # vec_e mask_e tok_e tok_d

        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 2, 3, 3]),  # vec_e mask_e tok_e tok_d tok_d

        # Embed in and out tokens; done together as weights may be shared.
        tl.Parallel(
            [],
            [],
            [],  # vec_e mask_e tok_e vec_d tok_d
            [tl.ShiftRight(mode=mode), out_encoder]),

        # Predict mode doesn't work with padding in the encoder. Raising an
        # exception inside a jitted function isn't possible, so the next best
        # thing is to convert every embedding to NaNs: the user then gets
        # clearly wrong results rather than subtly wrong ones.
        (_ConvertToNaNsOnAnyZero() if mode == 'predict' else []),

        # Encode.
        encoder,  # vec_e mask_e tok_e vec_d tok_d

        # Concat encoder and decoder, given encoder mask.
        encdec_layers,

        # Run decoder blocks.
        _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget,
                                forget_dense),  # vec_ed1 vec_ed2 tok_e tok_d
        _XYAvg(),  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

        # Map to output vocab.
        dense_loss_layer,  # vec_d tok_d
    )
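
For orientation, a hedged instantiation sketch with deliberately small, hypothetical
hyperparameters; it only shows how the arguments fit together and assumes the
Reformer2 definition above (and its helper layers) is importable.

tiny_reformer2 = Reformer2(
    input_vocab_size=320,
    d_model=128, d_ff=512,        # (d_model / 2) % n_heads must be 0
    n_encoder_layers=2, n_decoder_layers=2,
    n_heads=4,
    max_len=256,
    mode='train')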
Example #4
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              pos_type='fixed-base',
              pos_axial_shape=(),
              pos_d_axial_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              loss_sparsity_type='mult',
              loss_sparsity=0,
              loss_d_lowrank=0,
              loss_sparsity_prob=None,
              attention_chunk_size=0,
              n_layers_forget=0,
              n_decoder_attention_layers=2,
              use_bfloat16=False,
              reversible_encoder=False,
              use_two_swaps_per_encoder_block=True,
              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to use in the loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability of using the sparse version of
      the loss. If None, the sparse version is used exclusively.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_layers_forget: how often to have a forgetting block between layers
    n_decoder_attention_layers: how many attention layers in a decoder block
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    reversible_encoder: whether to be reversible through the encoder
    use_two_swaps_per_encoder_block: whether to allow even number of swaps in
      the encoder
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a (source, target) pair to
    activations over a vocab set.
  """
  # Set default dimensions for attention head key and value sizes.
  if d_attention_key is None:
    if d_model % n_heads != 0:
      raise ValueError(f'n_heads ({n_heads}) must divide d_model ({d_model})')
    d_attention_key = d_model // n_heads
  if d_attention_value is None:
    if d_model % n_heads != 0:
      raise ValueError(f'n_heads ({n_heads}) must divide d_model ({d_model})')
    d_attention_value = d_model // n_heads

  # Vector embeddings.
  in_encoder, out_encoder, output_vocab_size = (
      ct.EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model,
          mode,
          dropout,
          [-2],  # dropout_shared_axes
          max_len,
          output_vocab_size=output_vocab_size,
          pos_type=pos_type,
          pos_axial_shape=pos_axial_shape,
          pos_d_axial_embs=pos_d_axial_embs,
          use_bfloat16=use_bfloat16)
  )

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, encoder_attention_type,
          dropout=dropout,
          ff_activation=ff_activation,
          ff_dropout=ff_dropout,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          ff_sparsity=ff_sparsity,
          attention_chunk_size=attention_chunk_size,
          use_bfloat16=use_bfloat16,
          use_two_swaps_per_block=use_two_swaps_per_encoder_block,
          mode=mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = [     # vec_e mask_e tok_e tok_d tok_d
      tl.ReversibleSelect([0, 0]),     # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget)
  ]
  if not reversible_encoder:
    encoder += [
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.Dense(d_model, use_bfloat16=use_bfloat16),
        tl.LayerNorm(),
    ]
  encoder = tl.Serial(encoder)
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = []

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_dropout=ff_dropout,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        ff_sparsity=ff_sparsity,
        attention_chunk_size=attention_chunk_size,
        n_attention_layers=n_decoder_attention_layers,
        use_bfloat16=use_bfloat16,
        mode=mode)
    decoder_blocks.append(decoder_block)

  dense_loss_layer = tl.SparseDenseWithOptions(
      output_vocab_size,
      d_input=d_model,
      sparsity_type=loss_sparsity_type,
      sparsity=loss_sparsity,
      d_lowrank=loss_d_lowrank,
      prob_sparse=loss_sparsity_prob,
      use_bfloat16=use_bfloat16,
      mode=mode)

  # Layers to merge encoder and decoder, see below for details.
  if reversible_encoder:
    encdec_layers = [
        tl.ReversibleSelect([0, 1, 4, 2, 3]),  # vec_e vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding2(mode=mode),      # vec_ed vec_ed tok_e tok_d
    ]
  else:
    encdec_layers = [
        tl.ReversibleSelect([0, 3, 1, 2]),     # vec_e vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding(mode=mode),       # vec_ed tok_e tok_d
        tl.ReversibleSelect([0, 0]),           # vec_ed vec_ed tok_e tok_d
    ]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 0, 1, 1]),                # tok_e tok_e tok_e tok_d tok_d

      # Embed in and out tokens; done together as weights may be shared.
      tl.Parallel(in_encoder, [], [],            # vec_e tok_e tok_e vec_d tok_d
                  [tl.ShiftRight(mode=mode), out_encoder]),

      tl.Parallel([], [tl.PaddingMask(),
                       tl.Fn('Squeeze',
                             lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                         # vec_e mask_e tok_e vec_d tok_d

      # Encode.
      encoder,                                  # vec_e mask_e tok_e vec_d tok_d

      # Concat encoder and decoder, given encoder mask.
      encdec_layers,

      # Run decoder blocks.
      _ReversibleSerialForget(decoder_blocks, d_model,
                              n_layers_forget),    # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),           # vec_ed tok_e tok_d
      tl.LayerNorm(),                              # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                        # vec_ed tok_e tok_d tok_d
      t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

      # Map to output vocab.
      dense_loss_layer,  # vec_d tok_d
  )
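
A sketch of how one might initialize weights for the variant above from an input
signature (a source/target pair of int32 token arrays). The sizes are hypothetical,
and whether this runs as-is depends on the trax version the surrounding helpers
(EncoderBlock, DecoderBlock, t2.*) come from.

import numpy as np
from trax import shapes

model = Reformer2(input_vocab_size=256, d_model=64, d_ff=256,
                  n_encoder_layers=1, n_decoder_layers=1, n_heads=2,
                  max_len=64, mode='eval')
# One ShapeDtype per model input: (batch, length) int32 token IDs.
sig = (shapes.ShapeDtype((1, 16), dtype=np.int32),
       shapes.ShapeDtype((1, 16), dtype=np.int32))
model.init(sig)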
Example #5
def FunnelTransformer(vocab_size,
                      d_model=512,
                      d_ff=2048,
                      encoder_segment_lengths=(2, 2, 2),
                      n_decoder_blocks=2,
                      n_heads=8,
                      max_len=2048,
                      dropout=0.1,
                      dropout_shared_axes=None,
                      mode='train',
                      ff_activation=tl.Relu,
                      pool_layer=tl.AvgPool,
                      pool_size=(2, ),
                      separate_cls=True):
    """Returns a Full Funnel Transformer, that can be used for example for BERT.

  This model outputs token-level categorical distributions over all vocab:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions over `vocab_size` categories for each token; shape is
      (batch_size, sequence_length, vocab_size).


  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
        block.
    encoder_segment_lengths: Tuple, where each element denotes the number of
        transformer encoder blocks preceding a funnel transformer block.
        There is no funnel block after the last sequence of encoder blocks,
        therefore the total number of blocks in the model is equal to
        `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`.
    n_decoder_blocks: Number of transformer blocks in the upsampling decoder.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling in each of the
        funnel blocks; should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
        If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
        must be a tuple of length :math:`n-2`.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
        embeddings of the first token (`cls` from BERT paper) and only final
        embedding of this token is used for categorization - the rest are
        discarded. If `False`, each token from the beginning is pooled and
        all embeddings are averaged and mapped to output categories like in
        original `TransformerEncoder` model.
  """
    assert encoder_segment_lengths

    positional_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len)
    ]

    n_encoder_segments = len(encoder_segment_lengths)

    encoder_blocks_before_first_pooling = [
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation)
        for _ in range(encoder_segment_lengths[0])
    ]
    encoder_blocks_from_first_pooling = []

    for i in range(1, n_encoder_segments):
        # Building i'th segment

        # Add funnel block between segments
        encoder_blocks_from_first_pooling.append(
            _FunnelBlock(d_model,
                         d_ff,
                         n_heads,
                         dropout,
                         dropout_shared_axes,
                         mode,
                         ff_activation,
                         pool_layer,
                         pool_size=pool_size,
                         strides=pool_size,
                         separate_cls=separate_cls))

        for _ in range(encoder_segment_lengths[i]):
            # Create segment_size encoder blocks
            encoder_blocks_from_first_pooling.append(
                _EncoderBlock(d_model, d_ff, n_heads, dropout,
                              dropout_shared_axes, mode, ff_activation))

    decoder_blocks = [
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation) for _ in range(n_decoder_blocks)
    ]

    total_pool_size = pool_size[0]**(len(encoder_segment_lengths) - 1)

    # Assemble and return the model.
    return tl.Serial(  # toks
        tl.Branch(positional_encoder, tl.PaddingMask()),  # vecs masks
        encoder_blocks_before_first_pooling,  # vecs masks
        tl.Select([0, 1, 0, 1]),
        # vecs masks residual = vecs old_masks
        encoder_blocks_from_first_pooling,  # vecs masks residual masks
        tl.Select([0, 2, 3]),  # vecs residual masks
        tl.Parallel(
            # residual from first segment is taken before
            # normalization, so apply it now
            None,
            tl.LayerNorm(),
            None),  # vecs norm(residual) masks
        _Upsampler(total_pool_size, separate_cls),  # vecs masks
        decoder_blocks,
        tl.Select([0], n_in=2),  # vecs
        tl.LayerNorm(),
        tl.Dense(vocab_size),
    )
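
A hedged sketch of building and initializing a small FunnelTransformer; the sizes
are hypothetical and it assumes the _EncoderBlock, _FunnelBlock and _Upsampler
helpers used above are importable alongside this function.

import numpy as np
from trax import shapes

model = FunnelTransformer(vocab_size=1000, d_model=64, d_ff=128,
                          encoder_segment_lengths=(1, 1), n_decoder_blocks=1,
                          n_heads=2, max_len=64, mode='eval')
# Token IDs of shape (batch, sequence_length); with one funnel block and
# pool_size=(2,), an even sequence length keeps the pooling clean.
model.init(shapes.ShapeDtype((2, 16), dtype=np.int32))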
Example #6
def SerializedPolicy(seq_model, n_controls, n_actions, observation_serializer,
                     action_serializer):
    """Wraps a policy in serialization machinery for training.

  The resulting model takes as input observation and action sequences, and
  serializes them into one sequence similar to SerializedModel, before passing
  to the given sequence model. Adds output heads for action logits and value
  predictions.

  Args:
    seq_model: Trax sequence model taking as input a sequence of symbols and
      outputting a sequence of continuous vectors.
    n_controls: Number of controls.
    n_actions: Number of action categories in each control.
    observation_serializer: Serializer to use for observations.
    action_serializer: Serializer to use for actions.

  Returns:
    A model of signature (obs, act) -> (act_logits, values), same as in
    RawPolicy.
  """
    if action_serializer.representation_length != n_controls:
        raise ValueError(
            'Action symbols should correspond 1-1 to controls, but got {} '
            'controls and {} symbols.'.format(
                n_controls, action_serializer.representation_length))

    def FirstSymbol():
        return tl.Fn('FirstSymbol', lambda x: x[:, :, 0])

    def PadRight(n_to_pad):
        def pad_right(x):
            pad_widths = [(0, 0), (0, n_to_pad)] + [(0, 0)] * (x.ndim - 2)
            return jnp.pad(x,
                           pad_widths,
                           mode='constant',
                           constant_values=x.dtype.type(0))

        return tl.Fn(f'PadRight({n_to_pad})', pad_right)

    action_head = [
        tl.Dense(n_actions),
        tl.LogSoftmax(),
    ]
    value_head = [
        # Take just the vectors corresponding to the first action symbol.
        FirstSymbol(),
        # Predict values.
        tl.Dense(1),
        # Get rid of the singleton dimension.
        tl.Flatten(),
    ]
    return tl.Serial(
        # (obs, act)
        tl.Parallel(Serialize(observation_serializer),
                    Serialize(action_serializer)),
        # (obs_repr, act_repr)
        Interleave(),
        # (obs_act_repr,)

        # Add one dummy action to the right - we'll use the output at its first
        # symbol to predict the value for the last observation.
        PadRight(action_serializer.representation_length),

        # Shift one symbol to the right, so we predict the n-th action symbol
        # based on action symbols 1..n-1 instead of 1..n.
        tl.ShiftRight(),
        seq_model,
        # (obs_act_hidden,)
        Deinterleave(observation_serializer.representation_length,
                     action_serializer.representation_length),
        # (obs_hidden, act_hidden)
        tl.Select([1, 1]),
        # (act_hidden, act_hidden)
        tl.Parallel(action_head, value_head),
        # (act_logits, values)
    )
Example #7
def ConfigurableTerraformer(input_vocab_size,
                            output_vocab_size=None,
                            d_model=512,
                            d_ff=2048,
                            d_attention_key=None,
                            d_attention_value=None,
                            n_encoder_layers=6,
                            n_decoder_layers=6,
                            n_heads=8,
                            dropout=0.1,
                            max_len=2048,
                            encoder_attention_type=tl.SelfAttention,
                            encoder_decoder_attention_type=tl.SelfAttention,
                            pos_type='fixed-base',
                            pos_axial_shape=(),
                            pos_d_axial_embs=None,
                            pos_start_from_zero_prob=1.0,
                            pos_max_offset_to_add=0,
                            ff_activation=tl.Relu,
                            ff_use_sru=0,
                            ff_chunk_size=0,
                            ff_dropout=None,
                            ff_sparsity=0,
                            loss_sparsity_type='mult',
                            loss_sparsity=0,
                            loss_d_lowrank=0,
                            loss_sparsity_prob=None,
                            attention_chunk_size=0,
                            n_layers_forget=0,
                            forget_dense=True,
                            n_decoder_attention_layers=2,
                            use_bfloat16=False,
                            reversible_encoder=False,
                            use_two_swaps_per_encoder_block=True,
                            center_layernorm=True,
                            half_before_layer=None,
                            double_after_layer=None,
                            mode='train'):
    """Returns a highly configurable Terraformer encoder-decoder model.

  This model maps paired text sequences (source and target) to float-valued
  losses. If ``input_vocab_size`` is not ``None``, the layer takes
  two input sequences:

    - inputs (2):

        - source: 2-D int array representing a batch of text strings via token
          IDs plus padding markers; shape is `(batch_size, sequence_length)`,
          where sequence_length <= ``max_len``. Array elements are in
          ``range(input_vocab_size)``, and 0 values mark padding positions.

        - target: 2-D int array representing a batch of text strings via token
          IDs plus padding markers; shape is `(batch_size, sequence_length)`,
          where sequence_length <= ``max_len``. Array elements are in
          ``range(output_vocab_size)``, and 0 values mark padding positions.

    - output: 1-D float array of losses; shape is `(batch_size)`.

  If ``input_vocab_size`` is ``None``, the layer takes three input sequences:

    - inputs (3):

        - source: 3-D float array representing a batch of already-embedded text
          strings; shape is `(batch_size, sequence_length, d_model)`, where
          sequence_length <= ``max_len``.

        - mask: 2-D int array representing active versus masked positions; 0
          values mark masked (padding) positions.

        - target: 2-D int array representing a batch of text strings via token
          IDs plus padding markers; shape is `(batch_size, sequence_length)`,
          where sequence_length <= ``max_len``. Array elements are in
          ``range(output_vocab_size)``, and 0 values mark padding positions.

    - output: 1-D float array of losses; shape is `(batch_size)`.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in ``range(vocab_size)``. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
        if ``None``, then input and target integers (token IDs) are assumed to
        come from the same vocabulary.
    d_model: Last/innermost dimension of activation arrays at most points in
        the model, including the initial embedding output.
    d_ff: Last/innermost dimension of special (typically wider)
        :py:class:`Dense` layer in the feedforward part of each encoder block.
    d_attention_key: Depth of key vectors in each attention head.
    d_attention_value: Depth of value vectors in each attention head.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within encoder/decoder blocks. The same rate is
        also used for attention dropout in encoder/decoder blocks.
    max_len: Maximum symbol length for positional encoding.
    encoder_attention_type: Type of attention to use in the encoder; must be
        an attention-type subclass of :py:class:`trax.layers.Layer`.
    encoder_decoder_attention_type: Type of attention to use in the decoder;
        must be an attention-type subclass of :py:class:`trax.layers.Layer`.
    pos_type: String indicating the type of positional embeddings to use.
    pos_axial_shape: Shape (tuple of ints) to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: Tuple of ints specifying the depth of position embedding
        for each axis. Tuple length must match ``pos_axial_shape``, and values
        must sum to ``d_model``.
    pos_start_from_zero_prob: Stochastic rate (probability) for starting
        positional encoding at position 0 during training. If 1.0, always start
        from position 0; if < 1.0, the non-zero starts will be uniformly
        distributed up to ``pos_max_offset_to_add``.
    pos_max_offset_to_add: Maximum offset to add to positions during training
        when randomizing. This offset plus input length must be less than
        ``max_len`` for all training examples.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of :py:class:`trax.layers.Layer`.
    ff_use_sru: If > 0, use this number of SRU layers in place of feedforward
        layers.
    ff_chunk_size: If > 0, chunk each feedforward layer into chunks of this
        size.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
        at feedforward nonlinearities.
    ff_sparsity: If > 0, use sparse feedforward blocks with this level of
        sparsity.
    loss_sparsity_type: String indicating the type of sparsity to use in the
        loss layer; see :py:class:`SparseDenseWithOptions` for options. If
        ``None``, use no sparsity.
    loss_sparsity: If > 0, use this level of sparsity in the loss layer.
    loss_d_lowrank: If > 0, use a (low-rank) intermediate layer, with this
        dimension, in the loss.
    loss_sparsity_prob: Stochastic rate (probability) for using the sparse
        version of the loss. If ``None``, use the sparse version exclusively.
    attention_chunk_size: If > 0, compute attention using chunks of this size.
    n_layers_forget: How often to have a forgetting block between layers.
    forget_dense: If True, use :py:class:`Dense` instances as forget layers;
        else use no-ops.
    n_decoder_attention_layers: Number of attention layers in a decoder block.
    use_bfloat16: If True, use bfloat16 for weights; else use float32.
    reversible_encoder: If True, make the encoder be reversible.
    use_two_swaps_per_encoder_block: If True, ensure that there is an even
        number of swaps across the encoder.
    center_layernorm: If True, use centering in :py:class:`LayerNorm` (the
        default); else omit centering (which is known as RMS normalization).
    half_before_layer: If not None, specifies an n'th layer such that all
        layers before the n'th use half the normal values for ``d_model`` and
        ``d_ff``.
    double_after_layer: If not None, specifies an n'th layer such that all
        layers after the n'th use double the normal values for ``d_model`` and
        ``d_ff``.
    mode: If ``'train'``, include dropout in each encoder/decoder block; else
        dropout layers have no effect.

  Returns:
    A Terraformer encoder-decoder as a layer that maps from source and target
    text sequences to a scalar loss.
  """
    # Set default dimensions for attention head key and value sizes.
    if (d_model / 2) % n_heads != 0:
        raise ValueError(
            f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')
    if d_attention_key is None:
        d_attention_key = d_model // n_heads
    if d_attention_value is None:
        d_attention_value = d_model // n_heads

    # Set values of d_model, d_ff and d_qkv for the first stage.
    d_model1, d_ff1 = d_model, d_ff
    d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value
    if half_before_layer:
        d_model1, d_ff1 = d_model // 2, d_ff // 2
        d_attention_key1 = d_attention_key // 2
        d_attention_value1 = d_attention_value // 2

    # Set values of d_model, d_ff and d_qkv for the final stage.
    d_model2, d_ff2 = d_model, d_ff
    d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value
    if double_after_layer:
        d_model2, d_ff2 = d_model * 2, d_ff * 2
        d_attention_key2 = d_attention_key * 2
        d_attention_value2 = d_attention_value * 2

    # Vector embeddings.
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(
            input_vocab_size,
            d_model1,
            mode,
            dropout,
            [-2],  # dropout_shared_axes
            max_len,
            output_vocab_size=output_vocab_size,
            pos_type=pos_type,
            pos_axial_shape=pos_axial_shape,
            pos_d_axial_embs=pos_d_axial_embs,
            pos_start_from_zero_prob=pos_start_from_zero_prob,
            pos_max_offset_to_add=pos_max_offset_to_add,
            use_bfloat16=use_bfloat16))

    def _EncoderBlock():
        return reformer.EncoderBlock(
            d_model1,
            d_ff1,
            n_heads,
            encoder_attention_type,
            dropout=dropout,
            ff_activation=ff_activation,
            ff_dropout=ff_dropout,
            ff_use_sru=ff_use_sru,
            ff_chunk_size=ff_chunk_size,
            ff_sparsity=ff_sparsity,
            attention_chunk_size=attention_chunk_size,
            center_layernorm=center_layernorm,
            use_bfloat16=use_bfloat16,
            use_two_swaps_per_block=use_two_swaps_per_encoder_block,
            mode=mode)

    def _Encoder():  # vec_e mask_e tok_e tok_d tok_d
        layers = [
            tl.ReversibleSelect([0, 0]),
            _ReversibleSerialForget(
                [_EncoderBlock() for _ in range(n_encoder_layers)], d_model1,
                n_layers_forget, forget_dense)
        ]
        if not reversible_encoder:
            layers += [
                _XYAvg(),
                tl.Dense(d_model1, use_bfloat16=use_bfloat16),
                tl.LayerNorm(),
            ]
        if mode == 'predict':
            return tl.Cache(tl.Serial(layers))
        else:
            return tl.Serial(layers)

    decoder_blocks = []

    if isinstance(encoder_decoder_attention_type, (tuple, list)):
        assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    else:
        encoder_decoder_attention_type = [encoder_decoder_attention_type]
    for layer_idx in range(n_decoder_layers):
        layer_attention_type = encoder_decoder_attention_type[
            layer_idx % len(encoder_decoder_attention_type)]
        # Grow d_model, d_ff, and d_qkv if requested.
        d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1
        if half_before_layer and layer_idx >= half_before_layer:
            d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value
        if double_after_layer and layer_idx > double_after_layer:
            d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2
        decoder_block = reformer.DecoderBlock(
            d_m,
            d_f,
            d_k,
            d_v,
            n_heads,
            attention_type=layer_attention_type,
            dropout=dropout,
            ff_activation=ff_activation,
            ff_dropout=ff_dropout,
            ff_use_sru=ff_use_sru,
            ff_chunk_size=ff_chunk_size,
            ff_sparsity=ff_sparsity,
            attention_chunk_size=attention_chunk_size,
            n_attention_layers=n_decoder_attention_layers,
            center_layernorm=center_layernorm,
            use_bfloat16=use_bfloat16,
            mode=mode)
        decoder_blocks.append(decoder_block)
        if half_before_layer and layer_idx == half_before_layer - 1:
            decoder_blocks.append(tl.ReversibleConcatenatePair())
        if double_after_layer and layer_idx == double_after_layer:
            decoder_blocks.append(tl.ReversibleConcatenatePair())

    def _Loss():
        return tl.SparseDenseWithOptions(output_vocab_size,
                                         d_input=d_model2,
                                         sparsity_type=loss_sparsity_type,
                                         sparsity=loss_sparsity,
                                         d_lowrank=loss_d_lowrank,
                                         prob_sparse=loss_sparsity_prob,
                                         use_bfloat16=use_bfloat16,
                                         mode=mode)

    def _enc_dec_concat():
        """Layers to merge encoder and decoder."""
        if reversible_encoder:
            return [
                tl.ReversibleSelect([0, 1, 4, 2, 3]),  # v_e v_d mask_e tok_e tok_d
                t2.ConcatWithPadding2(mode=mode),      # v_ed v_ed tok_e tok_d
            ]
        else:
            return [
                tl.ReversibleSelect([0, 3, 1, 2]),  # v_e v_d mask_e tok_e tok_d
                t2.ConcatWithPadding(mode=mode),    # v_ed tok_e tok_d
                tl.ReversibleSelect([0, 0]),        # v_ed v_ed tok_e tok_d
            ]

    def _inp_layers():
        if input_vocab_size is not None:
            return tl.AssertFunction(
                'bl,br->bld,bl,bl,br',  # b: batch, l/r: enc/dec length, d: vec depth
                tl.Serial(  # tok_e tok_d
                    tl.Select([0, 0, 0, 1]),
                    tl.Parallel(
                        in_encoder,
                        [tl.PaddingMask(), _RemoveAxes12()],
                    )))  # vec_e mask_e tok_e tok_d
        else:
            # Input in this case is vec_e, mask_e, tok_d. Where all downstream
            # operations expect tok_e, we give it instead mask_e, expecting that
            # downstream ops only are looking for padding/not padding.
            return tl.AssertFunction(
                'blf,bl,br->bld,bl,bl,br',  # f: in-feature depth, d: out-vector depth
                tl.Serial(  # vec_e mask_e tok_d
                    tl.Select([0, 1, 1, 2]),
                    tl.Parallel(in_encoder, [],
                                _AsTokenIDs())))  # vec_e mask_e tok_e tok_d

    # Assemble and return the model.
    return tl.Serial(
        _inp_layers(),  # vec_e mask_e tok_e tok_d
        tl.Select([0, 1, 2, 3, 3]),  # Copy decoder tokens for use in loss.

        # Embed in and out tokens; done together as weights may be shared.
        tl.Parallel([], [], [],
                    [tl.ShiftRight(mode=mode), out_encoder]),  # vec_e mask_e tok_e vec_d tok_d

        # Predict mode doesn't work with padding in encoder. Raising an exception
        # in jitted function isn't possible, so the next best thing is to convert
        # every embedding to NaNs, so the user will get unmistakably wrong
        # results.
        (_ConvertToNaNsOnAnyZero() if mode == 'predict' else []),

        # Encode; then concat encoder and decoder, given encoder mask.
        _Encoder(),  # vec_e mask_e tok_e vec_d tok_d
        _enc_dec_concat(),

        # Run decoder blocks.
        _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget,
                                forget_dense),  # vec_ed1 vec_ed2 tok_e tok_d
        _XYAvg(),  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector,
        # then compute loss.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d
        _Loss(),  # vec_d tok_d
    )
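
To illustrate the half_before_layer / double_after_layer arguments, a hedged sketch
with hypothetical sizes (it assumes this ConfigurableTerraformer and its helpers
are importable):

model = ConfigurableTerraformer(
    input_vocab_size=512,
    d_model=128, d_ff=512, n_heads=4,
    n_encoder_layers=2, n_decoder_layers=6,
    half_before_layer=2,    # encoder and decoder layers 0-1 run at d_model // 2
    double_after_layer=4,   # decoder layers after index 4 run at 2 * d_model
    mode='train')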
Example #8
def policy_and_value_net(bottom_layers_fn, observation_space, action_space,
                         vocab_size, two_towers):
    """A policy and value net function.

  Runs bottom_layers_fn either as a single network or as two separate towers.
  Attaches action and value heads and wraps the network in a policy wrapper.

  Args:
    bottom_layers_fn: Trax model to use as a policy network.
    observation_space (gym.Space): Observation space.
    action_space (gym.Space): Action space.
    vocab_size (int or None): Vocabulary size to use with a SerializedPolicy
      wrapper. If None, RawPolicy will be used.
    two_towers (bool): Whether to run bottom_layers_fn as two separate towers
      for action and value prediction.

  Returns:
    Pair (network, substitute_fn), where network is the final network and
      substitute_fn is a function (wrapped_tree, inner_tree) -> wrapped_tree
      for substituting weights or state of the constructed model based on the
      weights or state of a model returned from bottom_layers_fn. substitute_fn
      is used for initializing the policy from parameters of a world model.
  """
    kwargs = {}
    if vocab_size is not None:
        kwargs['vocab_size'] = vocab_size

    def wrapped_policy_fn():
        return serialization_utils.wrap_policy(
            bottom_layers_fn(**kwargs),
            observation_space,
            action_space,
            vocab_size,
        )

    # From the shared representation, one head computes action log-probabilities
    # and the other computes the value function.
    # NOTE: We use LogSoftmax instead of Softmax for numerical stability.
    if two_towers:
        # Two towers: run two two-head networks in parallel and drop one head from
        # each.
        net = tl.Serial([  # (obs, act)
            tl.Select([0, 1, 0, 1]),  # (obs, act, obs, act)
            tl.Parallel(
                wrapped_policy_fn(),
                wrapped_policy_fn(),
            ),  # (act_logits_1, vals_1, act_logits_2, vals_2)
            tl.Select([0, 3]),  # (act_logits_1, vals_2)
        ])

        def substitute_fn(wrapped_policy, inner_policy):
            return (wrapped_policy[:1] + [
                tuple(
                    # Substitute in both towers.
                    serialization_utils.substitute_inner_policy(  # pylint: disable=g-complex-comprehension
                        tower, inner_policy, vocab_size)
                    for tower in wrapped_policy[1])
            ] + [wrapped_policy[2:]])
    else:
        # One tower: run one two-headed network.
        net = wrapped_policy_fn()
        substitute_fn = functools.partial(
            serialization_utils.substitute_inner_policy,
            vocab_size=vocab_size,
        )
    return (net, substitute_fn)
Example #9
def test_dup_dup(self):
    layer = tl.Parallel(tl.Dup(), tl.Dup())
    xs = [np.array([1, 2, 3]), np.array([10, 20])]
    ys = layer(xs)
    self.assertEqual(as_list(ys),
                     [[1, 2, 3], [1, 2, 3], [10, 20], [10, 20]])
Example #10
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train',
                ff_activation=tl.Relu):
    """Returns a Transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a (source, target) pair to
    activations over a vocab set.
  """
    in_embed = [  # tokens
        tl.Embedding(d_model, input_vocab_size),  # vecs
        tl.Dropout(rate=dropout, mode=mode),  # vecs
        tl.PositionalEncoding(max_len=max_len),  # vecs
    ]

    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
        out_embed = in_embed
    else:
        out_embed = [  # tokens
            tl.Embedding(d_model, output_vocab_size),  # vecs
            tl.Dropout(rate=dropout, mode=mode),  # vecs
            tl.PositionalEncoding(max_len=max_len),  # vecs
        ]

    encoder_stack = (  # masks vectors --> masks vectors
        [
            EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode,
                         ff_activation) for i in range(n_encoder_layers)
        ])

    encoder_decoder_stack = (  # vecs_d masks vecs_e --> vecs_d masks vecs_e
        [
            EncoderDecoder(d_model, d_ff, n_heads, dropout, i, mode,
                           ff_activation) for i in range(n_decoder_layers)
        ])

    # Input: encoder_side_tokens, decoder_side_tokens
    return tl.Serial(  # tokens_e tokens_d
        tl.Parallel([], tl.Dup()),  # toks_e toks_d toks_d (for loss)
        tl.Swap(),  # toks_d toks_e ....

        # Encode.
        tl.Parallel(  # toks_d        toks_e
            [],
            [
                tl.Dup(),  # ______ toks_e toks_e
                tl.Parallel(in_embed, tl.PaddingMask()),  # ______ vecs_e masks
                encoder_stack,  # ______ vecs_e masks
                tl.LayerNorm(),  # ______ vecs_e .....
                tl.Swap()
            ]),  # ______ masks  vecs_e

        # Decode.                                  #        toks_d masks vecs_e
        tl.ShiftRight(),  #        toks_d ..... ......
        out_embed,  #        vecs_d ..... ......
        tl.Dup(),  # vecs_d vecs_d ..... ......
        tl.Parallel([], tl.EncoderDecoderMask()),  # ______    masks     ......
        encoder_decoder_stack,  # vecs_d    masks     vecs_e
        tl.Parallel([], tl.Drop(), tl.Drop()),  # vecs_d
        tl.LayerNorm(),  # vecs_d
        tl.Dense(output_vocab_size),  # vecs_d
        tl.LogSoftmax(),  # vecs_d
    )
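A minimal usage sketch, assuming the `EncoderBlock` and `EncoderDecoder` helpers referenced above are importable; the model sizes and shapes below are illustrative, not from the source.
import numpy as np
from trax import shapes

model = Transformer(input_vocab_size=128, d_model=32, d_ff=64,
                    n_encoder_layers=1, n_decoder_layers=1, n_heads=2,
                    max_len=16, mode='eval')
source = np.ones((2, 16), dtype=np.int32)   # batch of source token IDs
target = np.ones((2, 16), dtype=np.int32)   # batch of target token IDs
model.init(shapes.signature((source, target)))
outputs = model((source, target))
# First output: (2, 16, 128) log-probabilities over the vocab; second output:
# the copied target tokens kept on the stack for the loss.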
Example #11
0
def LSTMSeq2SeqAttn(input_vocab_size=256,
                    target_vocab_size=256,
                    d_model=512,
                    n_encoder_layers=2,
                    n_decoder_layers=2,
                    n_attention_heads=1,
                    attention_dropout=0.0,
                    mode='train'):
    """Returns an LSTM sequence-to-sequence model with attention.

  The input to the model is a pair (input tokens, target tokens), e.g.,
  an English sentence (tokenized) and its translation into German (tokenized).

  The model works as follows:
  * Input encoder runs on the input tokens and creates activations that
    are used as both keys and values in attention.
  * Pre-attention decoder runs on the targets and creates
    activations that are used as queries in attention.
  * Attention runs on the queries, keys and values masking out input padding.
  * Decoder runs on the result, followed by a cross-entropy loss.

  Args:
    input_vocab_size: int: vocab size of the input
    target_vocab_size: int: vocab size of the target
    d_model: int:  depth of embedding (n_units in the LSTM cell)
    n_encoder_layers: int: number of LSTM layers in the encoder
    n_decoder_layers: int: number of LSTM layers in the decoder after attention
    n_attention_heads: int: number of attention heads
    attention_dropout: float, dropout for the attention layer
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference

  Returns:
    An LSTM sequence-to-sequence model with attention.
  """
    input_encoder = tl.Serial(
        tl.Embedding(d_model, input_vocab_size),
        [tl.LSTM(d_model) for _ in range(n_encoder_layers)],
    )

    pre_attention_decoder = tl.Serial(
        tl.ShiftRight(mode=mode),
        tl.Embedding(d_model, target_vocab_size),
        tl.LSTM(d_model),
    )

    def PrepareAttentionInputs():
        """Layer that prepares queries, keys, values and mask for attention."""
        def F(encoder_activations, decoder_activations, input_tokens):
            keys = values = encoder_activations
            queries = decoder_activations
            # Mask is 1 where inputs are not padding (0) and 0 where they are padding.
            mask = (input_tokens != 0)
            # We need to add axes to the mask for attention heads and decoder length.
            mask = jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
            # Broadcast so mask is [batch, 1 for heads, decoder-len, encoder-len].
            mask = mask + jnp.zeros((1, 1, decoder_activations.shape[1], 1))
            return queries, keys, values, mask

        return tl.Fn('PrepareAttentionInputs', F, n_out=4)

    return tl.Serial(  # in-toks, target-toks
        tl.Select([0, 1, 0, 1]),  # in-toks, target-toks, in-toks, target-toks
        tl.Parallel(input_encoder, pre_attention_decoder),
        PrepareAttentionInputs(),  # q, k, v, mask, target-toks
        tl.Residual(
            tl.AttentionQKV(d_model,
                            n_heads=n_attention_heads,
                            dropout=attention_dropout,
                            mode=mode)),  # decoder-vecs, mask, target-toks
        tl.Select([0, 2]),  # decoder-vecs, target-toks
        [tl.LSTM(d_model) for _ in range(n_decoder_layers)],
        tl.Dense(target_vocab_size),
        tl.LogSoftmax())
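A numpy-only sketch (not from the source) of the mask manipulation done inside PrepareAttentionInputs above: a (batch, enc_len) padding mask is expanded to (batch, 1 head axis, dec_len, enc_len) so it can be consumed by AttentionQKV.
import numpy as np

input_tokens = np.array([[5, 7, 0, 0],
                         [3, 0, 0, 0]])           # 0 marks padding
dec_len = 3
mask = (input_tokens != 0)                        # (2, 4)
mask = np.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))   # (2, 1, 1, 4)
mask = mask + np.zeros((1, 1, dec_len, 1))        # broadcast to (2, 1, 3, 4)
print(mask.shape)                                 # (2, 1, 3, 4)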
Example #12
0
def LSTMSeq2SeqAttn(input_vocab_size=256,
                    target_vocab_size=256,
                    d_model=512,
                    n_encoder_layers=2,
                    n_decoder_layers=2,
                    n_attention_heads=1,
                    attention_dropout=0.0,
                    mode='train'):
    """Returns an LSTM sequence-to-sequence model with attention.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(input_vocab_size)`, and `0`
          values mark padding positions.

        - target: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(target_vocab_size)`, and `0`
          values mark padding positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `target_vocab_size`).

  An example use would be to translate (tokenized) sentences from English to
  German.

  The model works as follows:

  * Input encoder runs on the input tokens and creates activations that
    are used as both keys and values in attention.
  * Pre-attention decoder runs on the targets and creates
    activations that are used as queries in attention.
  * Attention runs on the queries, keys and values masking out input padding.
  * Decoder runs on the result, followed by a cross-entropy loss.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    target_vocab_size: Target vocabulary size.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    n_encoder_layers: Number of LSTM layers in the encoder.
    n_decoder_layers: Number of LSTM layers in the decoder after attention.
    n_attention_heads: Number of attention heads.
    attention_dropout: Stochastic rate (probability) for dropping an activation
        value when applying dropout within an attention block.
    mode: If `'predict'`, use fast inference. If `'train'`, each attention block
        will include dropout; else, it will pass all values through unaltered.

  Returns:
    An LSTM sequence-to-sequence model as a layer that maps from a
    source-target tokenized text pair to activations over a vocab set.
  """
    input_encoder = tl.Serial(
        tl.Embedding(input_vocab_size, d_model),
        [tl.LSTM(d_model) for _ in range(n_encoder_layers)],
    )

    pre_attention_decoder = tl.Serial(
        tl.ShiftRight(mode=mode),
        tl.Embedding(target_vocab_size, d_model),
        tl.LSTM(d_model, mode=mode),
    )

    def PrepareAttentionInputs():
        """Layer that prepares queries, keys, values and mask for attention."""
        def F(encoder_activations, decoder_activations, input_tokens):
            keys = values = encoder_activations
            queries = decoder_activations
            # Mask is 1 where inputs are not padding (0) and 0 where they are padding.
            mask = (input_tokens != 0)
            # We need to add axes to the mask for attention heads and decoder length.
            mask = jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
            # Broadcast so mask is [batch, 1 for heads, decoder-len, encoder-len].
            mask = mask + jnp.zeros((1, 1, decoder_activations.shape[1], 1))
            mask = mask.astype(jnp.float32)
            return queries, keys, values, mask

        return tl.Fn('PrepareAttentionInputs', F, n_out=4)

    return tl.Serial(  # in-toks, target-toks
        tl.Select([0, 1, 0, 1]),  # in-toks, target-toks, in-toks, target-toks
        tl.Parallel(input_encoder, pre_attention_decoder),
        PrepareAttentionInputs(),  # q, k, v, mask, target-toks
        tl.Residual(
            tl.AttentionQKV(
                d_model,
                n_heads=n_attention_heads,
                dropout=attention_dropout,
                mode=mode,
                cache_KV_in_predict=True)),  # decoder-vecs, mask, target-toks
        tl.Select([0, 2]),  # decoder-vecs, target-toks
        [tl.LSTM(d_model, mode=mode) for _ in range(n_decoder_layers)],
        tl.Dense(target_vocab_size),
        tl.LogSoftmax())
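A hedged sketch of fast inference with mode='predict' (cache_KV_in_predict above caches the encoder keys and values). It assumes trained weights are available; `model_file` is a hypothetical checkpoint path, and the use of trax.supervised.decoding.autoregressive_sample here is an assumption about how such a model would typically be decoded.
import numpy as np
from trax.supervised import decoding

model = LSTMSeq2SeqAttn(mode='predict')
model_file = 'model.pkl.gz'                                # hypothetical path
model.init_from_file(model_file, weights_only=True)
source = np.array([[12, 7, 41, 0, 0]], dtype=np.int32)     # one padded source sequence
tokens = decoding.autoregressive_sample(model, inputs=source,
                                        temperature=0.0, max_length=20)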
Example #13
0
def Dup2():
    """Copy first 2 elements of the stack: (a, b, ...) -> (a, b, a, b, ...)."""
    return [  # Stack is (a, b, ...)
        tl.Parallel(tl.Dup(), tl.Dup()),  # Stack is (a, a, b, b, ...)
        tl.Parallel([], tl.Swap())  # Stack is (a, b, a, b, ...)
    ]
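A quick check in the style of the tests in this file (assumed, not from the source): applying Dup2 to a stack (a, b) should yield (a, b, a, b).
import numpy as np
from trax import layers as tl

layer = tl.Serial(Dup2())
a, b = np.array([1, 2, 3]), np.array([10, 20])
ys = layer((a, b))
# Expected output stack: ([1, 2, 3], [10, 20], [1, 2, 3], [10, 20])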
Example #14
0
def LearnedQP(keys=None, values=None, binary=False):
    """Get (query, pos), make learned weight of qeury and return with pos."""
    return tl.Parallel(
        tl.Dense(1),
        QueryPositionKV(keys=keys, values=values, binary=binary),
    )
Example #15
0
 def some_layer():
     return tl.Parallel(DivideBy(2.0), DivideBy(5.0))
Example #16
0
 def test_div_div(self):
     layer = tl.Parallel(DivideBy(0.5), DivideBy(3.0))
     xs = [np.array([1, 2, 3]), np.array([30, 60])]
     ys = layer(xs)
     self.assertEqual(as_list(ys), [[2, 4, 6], [10, 20]])
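DivideBy is not defined in these snippets; a minimal stand-in (an assumption) consistent with the two snippets above is a one-input layer built with tl.Fn.
from trax import layers as tl

def DivideBy(val):
    """Returns a layer that divides its single input by `val`."""
    return tl.Fn('DivideBy', lambda x: x / val)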
Example #17
0
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              axial_pos_shape='fixed-base',
              d_axial_pos_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              n_layers_forget=0,
              mode='train'):
    """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    n_layers_forget: how often to have a forgetting block between layers
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
    # The current API for custom gradients assumes that a layer must be
    # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
    # masks on the stack. This causes jax to error, even though the so-called
    # "gradient" wrt the masks is never actually computed.
    # TODO(kitaev): remove this hack.
    if fastmath.is_backend(fastmath.Backend.JAX):
        jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    # Set default dimensions for attention head key and value sizes.
    if d_attention_key is None:
        if d_model % n_heads != 0:
            raise ValueError(
                f'n_heads ({n_heads}) must divide d_model ({d_model})')
        d_attention_key = d_model // n_heads
    if d_attention_value is None:
        if d_model % n_heads != 0:
            raise ValueError(
                f'n_heads ({n_heads}) must divide d_model ({d_model})')
        d_attention_value = d_model // n_heads

    # Vector embeddings.
    def Embedder(vocab_size):  # tokens --> vectors
        return [
            tl.Embedding(vocab_size, d_model),
            tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        ]

    in_embedder = Embedder(input_vocab_size)
    out_embedder = (in_embedder if output_vocab_size is None else
                    Embedder(output_vocab_size))

    def PositionalEnc(mode):
        return PositionalEncoding(mode, dropout, max_len, axial_pos_shape,
                                  d_axial_pos_embs)

    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's switched
    # to 'eval' mode instead.
    encoder_mode = 'eval' if mode == 'predict' else mode
    in_encoder = in_embedder + [PositionalEnc(encoder_mode)]
    out_encoder = out_embedder + [PositionalEnc(mode)]
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model,
                     d_ff,
                     n_heads,
                     encoder_attention_type,
                     dropout=dropout,
                     ff_activation=ff_activation,
                     ff_dropout=ff_dropout,
                     ff_use_sru=ff_use_sru,
                     ff_chunk_size=ff_chunk_size,
                     ff_sparsity=ff_sparsity,
                     mode=mode) for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = tl.Serial([  # vec_e mask_e tok_e tok_d tok_d
        tl.Dup(),  # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
        _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.Dense(d_model),
        tl.LayerNorm(),
    ])
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    decoder_blocks = []

    if isinstance(encoder_decoder_attention_type, (tuple, list)):
        assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    else:
        encoder_decoder_attention_type = [encoder_decoder_attention_type]
    for layer_idx in range(n_decoder_layers):
        layer_attention_type = encoder_decoder_attention_type[
            layer_idx % len(encoder_decoder_attention_type)]
        decoder_block = DecoderBlock(d_model,
                                     d_ff,
                                     d_attention_key,
                                     d_attention_value,
                                     n_heads,
                                     attention_type=layer_attention_type,
                                     dropout=dropout,
                                     ff_activation=ff_activation,
                                     ff_dropout=ff_dropout,
                                     ff_use_sru=ff_use_sru,
                                     ff_chunk_size=ff_chunk_size,
                                     ff_sparsity=ff_sparsity,
                                     mode=mode)
        decoder_blocks.append(decoder_block)

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 0, 1, 1]),  # tok_e tok_e tok_e tok_d tok_d

        # Embed in and out tokens; done together as weights may be shared.
        tl.Parallel(
            in_encoder, [], [],
            [tl.ShiftRight(mode=mode), out_encoder]),  # vec_e tok_e tok_e vec_d tok_d
        tl.Parallel([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)
        ]),                                            # vec_e mask_e tok_e vec_d tok_d

        # Encode.
        encoder,  # vec_e mask_e tok_e vec_d tok_d

        # Decode.
        tl.Select([3, 0, 1, 2]),  #  vec_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder, given encoder mask.
        tl.Select([1, 0]),  # vec_e vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding(mode=mode),  # vec_ed tok_e tok_d

        # Run (encoder and) decoder blocks.
        tl.Dup(),  # vec_ed1 vec_ed2 tok_e tok_d
        _ReversibleSerialForget(
            decoder_blocks, d_model,
            n_layers_forget),  # vec_ed1 vec_ed2 tok_e tok_d
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),  # vec_d tok_d
        tl.LogSoftmax(),  # vec_d tok_d
    )
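A simplified numpy sketch of the concatenate-then-strip flow around the decoder blocks above. This is an assumption for illustration only: the real ConcatWithPadding and StripFromConcatenateWithPadding layers also account for each example's actual, unpadded encoder length.
import numpy as np

batch, enc_len, dec_len, d = 2, 5, 3, 4
vec_e = np.ones((batch, enc_len, d))
vec_d = np.ones((batch, dec_len, d))
vec_ed = np.concatenate([vec_e, vec_d], axis=1)   # (2, 8, 4): decoder blocks run
                                                  # over encoder + decoder positions
vec_d_out = vec_ed[:, enc_len:, :]                # (2, 3, 4): keep the decoder part only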
Example #18
0
 def test_two_no_ops(self):
     layer = tl.Parallel([], None)
     xs = [np.array([1, 2, 3]), np.array([10, 20])]
     ys = layer(xs)
     self.assertEqual(as_list(ys), [[1, 2, 3], [10, 20]])
Example #19
0
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      mode='train'):
    """Reversible transformer language model with shortening.

  When shorten_factor is F and processing an input of shape [batch, length],
  we embed the (shifted-right) input and then group every F elements (along
  the length dimension) into a single vector -- so that the main body of the
  model processes a tensor of shape
    [batch, length // F, d_model].
  At the end, the sequence is un-shortened and an SRU is applied.
  This reduces the length processed inside the main model body, effectively
  making the model faster but possibly slightly less accurate.

  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, values must sum to d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    if not axial_pos_shape:
        positional_encoding = tl.PositionalEncoding(max_len=max_len,
                                                    dropout=dropout)
    else:
        assert d_axial_pos_embs is not None
        positional_encoding = tl.AxialPositionalEncoding(
            shape=axial_pos_shape,
            d_embs=d_axial_pos_embs,
            dropout_broadcast_dims=tuple(range(1,
                                               len(axial_pos_shape) + 1)),
            dropout=dropout)

    positional_embedder = [
        tl.Embedding(d_embedding, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        positional_encoding,
    ]

    decoder_blocks = []

    if isinstance(attention_type, (tuple, list)):
        assert n_layers % len(attention_type) == 0
    else:
        attention_type = [attention_type]
    for layer_idx in range(n_layers):
        layer_attention_type = attention_type[layer_idx % len(attention_type)]
        decoder_block = DecoderBlock(
            d_model,
            d_ff,
            d_attention_key,
            d_attention_value,
            n_heads,
            n_attention_chunks,
            attention_type=layer_attention_type,
            dropout=dropout,
            share_qk=(share_qk or issubclass(layer_attention_type,
                                             tl.LSHCausalAttention)),
            ff_activation=ff_activation,
            ff_use_sru=ff_use_sru,
            mode=mode)
        decoder_blocks.append(decoder_block)

    # pylint: disable=g-long-lambda
    return tl.Serial(
        tl.ShiftRight(),
        positional_embedder,
        tl.Dup(),  # Stack has (x, x), the first will be shortened
        # Before shortening, we need to pad by shorten factor so as not to leak
        # information into the future. To understand why, imagine shorten factor
        # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
        # would have 0ABC, which gets grouped to [0A][BC] on input, which is
        # predicting ABCD as targets. The problem is that [0A] has access to A
        # and [BC] has access to C -- the model would learn to copy them,
        # peeking into the future. Shifting twice, to [00][AB], solves the
        # problem, as the first "big" symbol becomes all-0 and the rest is
        # shifted enough.
        tl.ShiftRight(n_shifts=shorten_factor - 1),
        tl.Fn(
            lambda x: np.reshape(  # Shorten -- move to depth.
                x, (x.shape[0], x.shape[1] // shorten_factor, -1)),
            n_out=1),
        tl.Dense(d_model),
        tl.Dup(),  # Stack has (short_x, short_x, x)
        tl.ReversibleSerial(decoder_blocks),
        tl.Parallel([], tl.Drop()),
        tl.LayerNorm(),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        tl.Dense(shorten_factor * d_embedding),
        tl.Fn(
            lambda x: np.reshape(  # Prolong back.
                x, (x.shape[0], x.shape[1] * shorten_factor, -1)),
            n_out=1),
        tl.Concatenate(),  # Concatenate with just the embeddings.
        tl.CausalConv(d_embedding),
        tl.Relu(),
        tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
        tl.Dense(vocab_size),
        tl.LogSoftmax())
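A numpy sketch (not from the source) of the shorten/prolong reshapes used above, with shorten_factor F: (batch, length, d) becomes (batch, length // F, F * d) for the main body and is reshaped back at the end.
import numpy as np

batch, length, d, F = 2, 8, 4, 2
x = np.arange(batch * length * d, dtype=np.float32).reshape(batch, length, d)
short = np.reshape(x, (x.shape[0], x.shape[1] // F, -1))                  # (2, 4, 8)
prolonged = np.reshape(short, (short.shape[0], short.shape[1] * F, -1))   # (2, 8, 4)
assert np.array_equal(x, prolonged)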
Example #20
0
 def test_default_name(self):
     layer = tl.Parallel(tl.Dup(), tl.Dup())
     self.assertIn('Parallel', str(layer))
Example #21
0
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, share_qk,
                 ff_activation, ff_use_sru, ff_chunk_size, mode):
    """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    if share_qk:
        pre_attention = [
            Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
            tl.LayerNorm(),
            tl.Dup(),
            tl.Parallel(
                tl.ComputeAttentionHeads(n_heads=n_heads,
                                         d_head=d_attention_key),
                tl.ComputeAttentionHeads(n_heads=n_heads,
                                         d_head=d_attention_value),
            ),
            tl.Dup(),
        ]
    else:
        pre_attention = [
            Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
            tl.LayerNorm(),
            tl.Dup(),
            tl.Dup(),
            tl.Parallel(
                tl.ComputeAttentionHeads(n_heads=n_heads,
                                         d_head=d_attention_key),
                tl.ComputeAttentionHeads(n_heads=n_heads,
                                         d_head=d_attention_key),
                tl.ComputeAttentionHeads(n_heads=n_heads,
                                         d_head=d_attention_value),
            ),
        ]

    attention = attention_type(mode=mode)

    # ReversibleAttentionHalfResidual requires that post_attention be linear in
    # its input (so the backward pass can be computed without knowing the input)
    post_attention = [
        tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
        Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
    ]

    if ff_use_sru:
        feed_forward = [tl.SRU(d_model) for _ in range(ff_use_sru)]
    else:
        feed_forward = [
            ChunkedFeedForward(d_model, d_ff, dropout, ff_activation,
                               ff_chunk_size, mode)
        ]

    return [
        ReversibleAttentionHalfResidual(pre_attention, attention,
                                        post_attention),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
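A plain-numpy sketch (an illustration, not the library's implementation) of the reversible-residual arithmetic behind ReversibleAttentionHalfResidual, ReversibleHalfResidual and ReversibleSwap: the block's inputs can be reconstructed from its outputs in the backward pass, so activations need not be stored. The functions f and g below are arbitrary stand-ins for the attention and feed-forward halves.
import numpy as np

def f(x): return 2.0 * x      # stands in for the attention half-residual
def g(x): return x + 1.0      # stands in for the feed-forward half-residual

x1, x2 = np.array([1.0, 2.0]), np.array([3.0, 4.0])
# Forward: two half-residuals with a swap in between.
y1 = x1 + f(x2)
y2 = x2 + g(y1)
# Backward reconstruction: recover the inputs from the outputs alone.
x2_rec = y2 - g(y1)
x1_rec = y1 - f(x2_rec)
assert np.allclose(x1, x1_rec) and np.allclose(x2, x2_rec)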
Example #22
0
 def test_custom_name(self):
     layer = tl.Parallel(tl.Dup(), tl.Dup(), name='DupDup')
     self.assertIn('DupDup', str(layer))
Example #23
0
def _FunnelBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                 ff_activation, pool_layer, pool_size, strides, separate_cls):
    """Internal funnel block. Returns a list of layers implementing it.

  The input is an activation tensor.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling;
        should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
        If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
        must be a tuple of length :math:`n-2`.
    strides: Offsets from the location of one window to the locations of
        neighboring windows along each axis. If specified, must be a tuple of
        the same length as `pool_size`. If None, then offsets of 1 along each
        window axis, :math:`(1, ..., 1)`, will be used.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
        embeddings of the first token (`cls` from BERT paper).

  Returns:
      A list of layers that maps (activations, mask) to (activations', mask').
  """
    pooling = PoolLayer(pool_layer, pool_size, strides, separate_cls)
    mask_pooling = MaskPool(pool_size, strides, separate_cls)

    attention = tl.AttentionQKV(d_model,
                                n_heads=n_heads,
                                dropout=dropout,
                                mode=mode)
    hidden_dropout = tl.Dropout(rate=dropout,
                                shared_axes=dropout_shared_axes,
                                mode=mode)

    feed_forward = _FeedForwardBlock(d_model, d_ff, dropout,
                                     dropout_shared_axes, mode, ff_activation)

    return [  # h, mask
        tl.LayerNorm(),  # h, mask
        tl.Branch(pooling, None),  # h', h, mask
        tl.Residual(
            tl.Select([0, 1, 1, 2]),  # h', h, h, mask
            attention,  # attn, mask
            tl.Parallel(None, mask_pooling),  # attn, mask'
            hidden_dropout  # attn, mask'
        ),  # funnel_activations, mask'
        tl.Residual(
            tl.LayerNorm(),
            feed_forward,
            hidden_dropout,
        )
    ]
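A numpy sketch (an illustrative assumption, not the PoolLayer/MaskPool implementation) of what the funnel block's pooling does along the sequence axis: with pool_size=(2,) and strides=(2,), both the activations used as queries and the mask are halved in length.
import numpy as np

batch, seq_len, d = 2, 8, 4
h = np.random.rand(batch, seq_len, d)
mask = np.ones((batch, seq_len), dtype=bool)
pooled_h = h.reshape(batch, seq_len // 2, 2, d).mean(axis=2)     # (2, 4, 4), average pooling
pooled_mask = mask.reshape(batch, seq_len // 2, 2).max(axis=2)   # (2, 4), keep a position if any
                                                                 # element in its window is valid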
Example #24
0
 def test_weights(self):
     model = tl.Parallel(tl.Dense(3), tl.Dense(5))
     self.assertIsInstance(model.weights, tuple)
     self.assertLen(model.weights, 2)
Example #25
0
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train',
                ff_activation=tl.Relu):
    """Returns a Transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
    def PositionalEmbedder(vocab_size):  # tokens --> vectors
        return [
            tl.Embedding(d_model, vocab_size),
            tl.Dropout(rate=dropout, mode=mode),
            tl.PositionalEncoding(max_len=max_len),
        ]

    def EncoderBlocks(n_blocks):  # vectors masks --> vectors masks
        return [
            _EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode,
                          ff_activation) for i in range(n_blocks)
        ]

    def EncoderDecoderBlocks(n_blocks):  # vectors masks --> vectors masks
        return [
            _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, i, mode,
                                 ff_activation) for i in range(n_blocks)
        ]

    in_embed = PositionalEmbedder(input_vocab_size)
    out_embed = (in_embed if output_vocab_size is None else
                 PositionalEmbedder(output_vocab_size))
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d

        # Encode.
        tl.Branch(in_embed, tl.PaddingMask()),  # vec_e masks ..... .....
        EncoderBlocks(n_encoder_layers),  # vec_e masks ..... .....
        tl.LayerNorm(),  # vec_e ..... ..... .....

        # Decode.
        tl.Select([2, 1, 0]),  # tok_d masks vec_e .....
        tl.ShiftRight(),  # tok_d ..... ..... .....
        out_embed,  # vec_d ..... ..... .....
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        EncoderDecoderBlocks(n_decoder_layers),  # vec_d masks ..... .....
        tl.LayerNorm(),  # vec_d ..... ..... .....

        # Map to output vocab.
        tl.Parallel([], tl.Drop(), tl.Drop()),  # vec_d tok_d
        tl.Dense(output_vocab_size),  # vec_d .....
        tl.LogSoftmax(),  # vec_d .....
    )
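A small check in the style of the tests in this file (assumed, not from the source) of the stack manipulation at the top of the model: tl.Select([0, 1, 1]) copies the decoder tokens so one copy survives for the loss.
import numpy as np
from trax import layers as tl

layer = tl.Select([0, 1, 1])
tok_e, tok_d = np.array([1, 2, 3]), np.array([7, 8, 9])
ys = layer((tok_e, tok_d))
# Expected output stack: ([1, 2, 3], [7, 8, 9], [7, 8, 9])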
Example #26
0
 def test_shared_weights_nested(self):
     layer = tl.Dense(5)
     model = tl.Parallel([layer, tl.Dense(2)], [layer, tl.Dense(2)])
     sample_input = (np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5]))
     weights, _ = model.init(shapes.signature(sample_input))
     self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE)
Example #27
0
def LatentTransformer(input_vocab_size,
                      output_vocab_size=None,
                      d_model=512,
                      d_ff=2048,
                      n_encoder_layers=6,
                      n_decoder_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      dropout_shared_axes=None,
                      max_len=2048,
                      mode='train',
                      ff_activation=tl.Relu,
                      axial_pos_shape=None,
                      d_axial_pos_embs=None):
    """Returns a Transformer model.

  This model expects an input pair: target, source.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.

  Returns:
    A Transformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(input_vocab_size,
                                           d_model,
                                           mode,
                                           dropout,
                                           dropout_shared_axes,
                                           max_len,
                                           output_vocab_size=output_vocab_size,
                                           axial_pos_shape=axial_pos_shape,
                                           d_axial_pos_embs=d_axial_pos_embs))

    encoder_blocks = [
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation) for i in range(n_encoder_layers)
    ]

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    decoder_blocks = [
        _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation) for i in range(n_decoder_layers)
    ]

    compress_seq = tl.Serial(
        # input:                            #   tok
        tl.Branch([], tl.PaddingMask()),  #   tok mask
        encoder,  #   vec mask
        PickFirst(),  # vec_f mask
        tl.Select([0], n_in=2))  # vec_f

    latent_transition = tl.Serial(
        tl.Parallel([tl.Dense(d_model), tl.Relu()],
                    [tl.Dense(d_model), tl.Relu()]), tl.Add(),
        tl.Residual(
            tl.LayerNorm(),
            tl.Dense(d_model),
            tl.Relu(),
            tl.Dropout(rate=dropout, mode=mode),
            tl.Dense(d_model),
        ))

    pred_valid = tl.Serial(tl.Dense(2), Squeeze(1))

    embed_tgt = tl.Serial(
        # Input                             #  tok_d
        DropLast(mode=mode),  # stok_d
        out_encoder,  # svec_d
    )

    decode_seq = tl.Serial(
        # Input:                                 #  vec_e  tok_d
        tl.Select([1, 0, 1]),  #  tok_d  vec_e tok_d
        tl.Parallel(embed_tgt, [], DropFirst()),  # svec_d  vec_e tok_d'
        ConcatDeEntoEnDe(),  # vec_ed tok_d'
        # Decoder blocks with causal attention
        decoder_blocks,  # vec_ed tok_d'
        tl.LayerNorm(),  # vec_ed tok_d'
        DropFirst(),  #  vec_d tok_d'
        # Map to output vocab.
        tl.Dense(output_vocab_size),  # pred_d tok_d'
    )

    # compress_seq: n_in 1 n_out 1: add mask, encode, pick last hidden
    # latent_transition: n_in 2 n_out 1: s, a -> s_1
    # pred_valid: n_in 1 n_out 1: s_1 -> pred_v
    # decode_seq: n_in 2 n_out 2: copy target, shift right, decode, output

    return tl.Serial(
        # Input:                                  tok_s tok_a tok_s1 r v
        tl.Select([0, 1, 2, 0, 1, 3, 4]),       # tok_s tok_a tok_s1 tok_s tok_a r v

        # Encode.
        tl.Parallel(compress_seq,
                    compress_seq),              # vec_s vec_a tok_s1 tok_s tok_a r v
        tl.Branch(latent_transition, [],
                  tl.Select([1], n_in=2)),      # vec_s1 vec_s vec_a tok_s1 tok_s tok_a r v
        tl.Branch(pred_valid, []),              # pred_v vec_s1 vec_s vec_a tok_s1 tok_s tok_a r v

        # Decode.
        tl.Select([1, 4, 2, 5, 3, 6, 0, 8, 7]), # vec_s1 tok_s1 vec_s tok_s vec_a tok_a pred_v v r
        tl.Parallel(decode_seq, decode_seq,
                    decode_seq),                # pred_s1 tok_s1 pred_s tok_s pred_a tok_a pred_v v r
    )
Example #28
0
def EncoderBlock(d_model,
                 d_ff,
                 n_heads,
                 attention_type,
                 dropout,
                 ff_activation,
                 ff_dropout,
                 ff_use_sru=0,
                 ff_chunk_size=0,
                 ff_sparsity=0,
                 attention_chunk_size=0,
                 center_layernorm=True,
                 use_bfloat16=False,
                 use_two_swaps_per_block=True,
                 mode='train'):
    """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    center_layernorm: whether to use centering in LayerNorm (default) or to
      skip it, which is known as RMS normalization.
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    use_two_swaps_per_block: bool, if True use two reversible swaps in Encoder
      block, otherwise use only one swap.
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
    if mode == 'predict':
        # Mode 'predict' means that the decoder should be run one token at a time.
        # The encoder only ever runs over full sequences, which is why it's switched
        # to 'eval' mode instead.
        mode = 'eval'

    def _Attn():
        return ct.ApplyAttentionLayer(
            attention_type=attention_type,
            d_model=d_model,
            n_heads=n_heads,
            d_qk=d_model // n_heads,
            d_v=d_model // n_heads,
            masked=True,
            causal=False,
            attention_dropout=dropout,
            output_dropout=dropout,
            attention_chunk_size=attention_chunk_size,
            mode=mode)

    def _FF():
        return ct.FeedForwardWithOptions(d_model, d_ff, dropout, [-2],
                                         ff_activation, ff_dropout,
                                         ff_chunk_size, ff_use_sru,
                                         ff_sparsity, center_layernorm, mode,
                                         use_bfloat16)

    # TODO(lukaszkaiser): refactor efficient attention layers to unify the API
    # If we're using standard attention, we need to pass reshaped mask and not
    # return the mask to be compatible with the EfficientAttention API.
    attention = _Attn()
    if attention.n_out == 2:
        attention = tl.Serial(tl.Parallel([], _InsertAxes12()), attention,
                              tl.Select([0], n_in=2))

    def _attention_half_residual():
        return [
            tl.ReversibleHalfResidual(
                tl.LayerNorm(center=center_layernorm),
                attention_layer=attention,
                name='ReversibleHalfResidualEncoderAttn'),
            tl.ReversibleSwap()
        ]

    def _feed_forward():
        layers = [
            tl.ReversibleHalfResidual(_FF(),
                                      name='ReversibleHalfResidualEncoderFF')
        ]
        if use_two_swaps_per_block:
            layers.append(tl.ReversibleSwap())
        return layers

    return _attention_half_residual() + _feed_forward()
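A numpy sketch (illustrative only) of the center_layernorm option documented above: standard LayerNorm subtracts the per-vector mean before scaling, while RMS normalization skips the centering step.
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
eps = 1e-6
layer_norm = (x - x.mean()) / np.sqrt(x.var() + eps)   # centered (default)
rms_norm = x / np.sqrt(np.mean(x ** 2) + eps)          # center_layernorm=False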