Example #1
    def test_train_eval_predict_sm3(self):
        with self.tmp_dir() as output_dir:
            # Prepare model and inputs
            n_classes = 4
            train_steps = 2
            eval_steps = 2
            model_fn = functools.partial(models.MLP,
                                         d_hidden=16,
                                         n_output_classes=n_classes)
            inputs = lambda _: test_inputs(n_classes)

            # Train and evaluate
            state = trax.train(output_dir,
                               model=model_fn,
                               inputs=inputs,
                               train_steps=train_steps,
                               eval_steps=eval_steps,
                               optimizer=trax_opt.SM3)

            # Assert total train steps
            self.assertEqual(train_steps, state.step)

            # Assert 2 evaluations ran
            train_acc = state.history.get("train", "metrics/accuracy")
            eval_acc = state.history.get("eval", "metrics/accuracy")
            self.assertEqual(len(train_acc), len(eval_acc))
            self.assertEqual(2, len(eval_acc))

            # Predict with final params
            inputs = inputs(1).train_stream()
            model = layers.Serial(model_fn())
            model(next(inputs)[0], state.params[0])
Example #2
 def Generator(encoded_target):
   return layers.Serial(
       encoded_target,
       layers.Dense(target_vocab_size,
                    kernel_initializer=layers.XavierUniformInitializer()),
       layers.LogSoftmax
   )
Example #3
def SumLearnedPick(positions):
  """Get a pair (vec, pos) and pick new pos."""
  succ_keys = positions[:-1, :]
  succ_values = positions[1:, :]
  subtract_1_keys = positions[1:, :]
  subtract_1_values = positions[:-1, :]
  l = int(positions.shape[0]) // 2
  add_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                       for i in range(l) for j in range(l)])
  add_values = np.array([positions[i + j, :]
                         for i in range(l) for j in range(l)])
  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  sub_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                       for j in range(l) for i in range(l)])
  sub_values = np.array([positions[max(i - j, 0), :]
                         for j in range(l) for i in range(l)])
  return tl.Serial(
      tl.Branch(
          LearnedQP(),
          LearnedQP(keys=succ_keys, values=succ_values),
          LearnedQP(keys=subtract_1_keys, values=subtract_1_values),
          LearnedQP(keys=add_keys, values=add_values, binary=True),
          LearnedQP(keys=sub_keys, values=sub_values, binary=True),
      ),
      Unnest(),
      SoftmaxBranches(n_branches=5)
  )
Example #4
def PositionLookupTransformerLM(vocab_size=128,
                                d_feature=256,
                                d_feedforward=512,
                                n_layers=3,
                                n_heads=4,
                                dropout=0.1,
                                max_len=100,
                                mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    d_feature: int:  depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: maximal length
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  positions = _POSITIONS[:max_len, :]
  return tl.Serial([
      tl.ShiftRight(),
      tl.Embedding(d_feature, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      NewPositionalEncoding(positions=positions),
      [DecoderLayer(positions, d_feature, d_feedforward, n_heads, dropout, mode)
       for _ in range(n_layers)],
      PreservePosition(tl.LayerNorm()),
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  ])
Example #5
    def test_train_eval_predict_sm3(self, backend_name):
        if xla_bridge.device_count() > 1 and backend_name == "tf":
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            # Prepare model and inputs
            n_classes = 4
            train_steps = 2
            eval_steps = 2
            model_fn = functools.partial(models.MLP,
                                         d_hidden=16,
                                         n_output_classes=n_classes)
            inputs = lambda _: test_inputs(n_classes)

            # Train and evaluate
            state = trax.train(output_dir,
                               model=model_fn,
                               inputs=inputs,
                               train_steps=train_steps,
                               eval_steps=eval_steps,
                               optimizer=trax_opt.SM3)

            # Assert total train steps
            self.assertEqual(train_steps, state.step)

            # Assert 2 evaluations ran
            train_acc = state.history.get("train", "metrics/accuracy")
            eval_acc = state.history.get("eval", "metrics/accuracy")
            self.assertEqual(len(train_acc), len(eval_acc))
            self.assertLen(eval_acc, 2)

            # Predict with final params
            inputs = inputs(1).train_stream()
            model = layers.Serial(model_fn())
            model(next(inputs)[0], params=state.opt_state.params)
Example #6
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Transformer decoder layer.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    return layers.Serial(
        layers.Residual(  # Self-attention block.
            layers.LayerNorm(),
            layers.Branch(),
            layers.Parallel(
                layers.Identity(),  # activation for (q, k, v)
                layers.CausalMask(axis=-2)),  # attention mask
            layers.MultiHeadedAttention(feature_depth,
                                        num_heads=num_heads,
                                        dropout=dropout,
                                        mode=mode),
            layers.Dropout(rate=dropout, mode=mode)),
        ResidualFeedForward(feature_depth,
                            feedforward_depth,
                            dropout,
                            mode=mode))
Example #7
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Transformer decoder layer.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    return tl.Serial(
        tl.Residual(  # Self-attention block.
            tl.LayerNorm(),
            tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # Create mask.
            tl.MultiHeadedAttention(feature_depth,
                                    num_heads=num_heads,
                                    dropout=dropout,
                                    mode=mode),
            tl.Select(0),  # Drop the mask.
            tl.Dropout(rate=dropout, mode=mode)),
        ResidualFeedForward(feature_depth,
                            feedforward_depth,
                            dropout,
                            mode=mode))
Example #8
def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Transformer encoder layer.

  The input to the encoder is a pair (embedded source, mask) where
  the mask is created from the original source to prevent attending
  to the padding part of the input.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (actiavtions, mask).
  """
    return tl.Serial(
        tl.Residual(  # Attention block here.
            tl.Parallel(tl.LayerNorm(), tl.Copy()),
            tl.MultiHeadedAttention(feature_depth,
                                    num_heads=num_heads,
                                    dropout=dropout,
                                    mode=mode),
            tl.Parallel(tl.Dropout(rate=dropout, mode=mode), tl.Copy())),
        tl.Parallel(
            ResidualFeedForward(feature_depth,
                                feedforward_depth,
                                dropout,
                                mode=mode),
            tl.Div(
                divisor=2.0)  # Mask added to itself in the residual, divide.
        ))
Example #9
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10,
               mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.
    mode: is it training or eval.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
  del mode
  return tl.Serial(
      tl.Conv(hidden_size, (3, 3), padding='SAME'),
      WideResnetGroup(num_blocks, hidden_size),
      WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
      WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
      tl.BatchNorm(),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(num_output_classes),
      tl.LogSoftmax()
  )
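
A quick shape check for this model, in the style of Example #19. This is a minimal sketch that assumes the WideResnet definition above is in scope and that tl.check_shape_agreement accepts (layer, input_shape) as in Example #19; the input size is an illustrative CIFAR-10 batch.

# Hedged sketch -- reuses tl as imported by the example above.
model = WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10)
input_shape = (8, 32, 32, 3)  # (batch, height, width, channels), illustrative only
final_shape = tl.check_shape_agreement(model, input_shape)
print(final_shape)  # expected (8, 10): Flatten + Dense leave (batch, num_output_classes)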
Example #10
def ChunkedDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        chunk_selector, mode):
    """Transformer decoder layer operating on chunks.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    chunk_selector: a function from chunk number to list of chunks to attend.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    return layers.Serial(
        layers.Residual(  # Self-attention block.
            layers.Map(layers.LayerNorm()),
            layers.ChunkedCausalMultiHeadedAttention(
                feature_depth,
                num_heads=num_heads,
                dropout=dropout,
                chunk_selector=chunk_selector,
                mode=mode),
            layers.Map(layers.Dropout(rate=dropout, mode=mode)),
        ),
        layers.Map(
            ResidualFeedForward(feature_depth,
                                feedforward_depth,
                                dropout,
                                mode=mode)))
Example #11
def _jit_predict_fn(model_predict, metric_fn, n_devices, jit=True):
    """Returns a JIT-compiled predict function (unless jit=False)."""
    model_predict = layers.Serial([model_predict, metric_fn])

    if n_devices == 1:
        return backend.jit(model_predict) if jit else model_predict

    # Multi-devices, pmap and run.
    @functools.partial(backend.pmap, axis_name="batch")
    def mapped_predict(x, params, state, rng):
        return model_predict(x, params=params, state=state, rng=rng)

    def predict(x, params=(), state=(), rng=None):
        """Predict function jited and parallelized as requested."""
        pred = mapped_predict(reshape_by_device(x, n_devices), params, state,
                              jax_random.split(rng, n_devices))

        # Need to reduce the [device, per-device-batch, ...] tensors back to
        # a [batch, ...] tensor. The tensors may be nested.
        def combine(x):
            if len(x.shape) > 1:
                batch_size = x.shape[0] * x.shape[1]
                return np.reshape(x, [batch_size] + list(x.shape[2:]))
            # TODO(lukaszkaiser): is returning averages for scalars the right choice?
            # If it is only scalar, return the average.
            return np.mean(x, axis=0)

        return layers.nested_map(pred, combine)

    return predict
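
The combine helper above collapses the leading device axis back into the batch axis. A small standalone check of that reshape (plain NumPy, independent of the rest of the example):

import numpy as np

n_devices, per_device_batch = 2, 4
pred = np.arange(n_devices * per_device_batch * 3).reshape(n_devices, per_device_batch, 3)
# [device, per-device-batch, ...] -> [batch, ...], exactly as combine() does above.
combined = np.reshape(pred, [n_devices * per_device_batch] + list(pred.shape[2:]))
assert combined.shape == (8, 3)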
Example #12
def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         n_actions,
                         bottom_layers_fn=(),
                         two_towers=True):
    """A policy and value net function."""

    # Layers.

    # Now, with the current logits, one head computes action probabilities and the
    # other computes the value function.
    # NOTE: The LogSoftmax instead of the Softmax because of numerical stability.

    if two_towers:
        net = tl.Branch(
            [bottom_layers_fn(),
             tl.Dense(n_actions),
             tl.LogSoftmax()],
            [bottom_layers_fn(), tl.Dense(1)])
    else:
        net = tl.Serial(
            bottom_layers_fn(),
            tl.Branch(
                [tl.Dense(n_actions), tl.LogSoftmax()], [tl.Dense(1)]))
    return net.initialize(batch_observations_shape, rng_key), net
Example #13
    def __init__(self, pre_attention, attention, post_attention):
        self.pre_attention = tl.Serial([
            # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
            tl.Parallel([], tl.Dup()),
            tl.Swap(),
            tl.Parallel(pre_attention, [], []),
        ])
        assert hasattr(attention, 'forward_and_vjp')
        self.attention = ApplyAttentionWrapper(attention)
        self.post_attention = tl.Parallel(post_attention, [], [])

        layers = [
            self.pre_attention,
            self.attention,
            self.post_attention,
            tl.Parallel(tl.Add(), []),
        ]
        super(ReversibleAttentionHalfResidual, self).__init__(layers)

        self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
        self.reverse_layers = [
            self.pre_attention,
            self.attention,
            self.post_attention,
            self.subtract_top,
        ]
Example #14
def ChunkedDecoderLayer(d_feature, d_feedforward, n_heads, dropout,
                        chunk_selector, mode):
    """Transformer decoder layer operating on chunks.

  Args:
    d_feature: int:  depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    chunk_selector: a function from chunk number to list of chunks to attend.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    return tl.Serial(
        Residual(  # Self-attention block.
            tl.Map(tl.LayerNorm()),
            ChunkedCausalMultiHeadedAttention(d_feature,
                                              n_heads=n_heads,
                                              dropout=dropout,
                                              chunk_selector=chunk_selector,
                                              mode=mode),
            tl.Map(tl.Dropout(rate=dropout, mode=mode)),
        ),
        tl.Map(
            ResidualFeedForward(d_feature, d_feedforward, dropout, mode=mode)))
Example #15
def PreservePosition(layer):
  """Execute layer without position but preserve it in parallel."""
  return tl.Serial(
      CutAtPosition(),
      layer,
      tl.Concatenate(n_items=2)
  )
Example #16
def PreservePosition(layer):
  """Execute layer without position but preserve it in parallel."""
  return tl.Serial(
      CutPosition(),
      layer,
      ConcatenateN()
  )
Example #17
def AttentionPosition(positions,
                      d_model,
                      n_heads=8,
                      dropout=0.0,
                      mode='train'):
    """Transformer-style multi-headed attention."""
    return tl.Serial(
        tl.Dup(),
        tl.Dup(),
        tl.Parallel(
            ApplyAndQueryPositions(
                tl.Dense(d_model),
                pos=[SumLearnedPick(positions) for _ in range(n_heads)]),
            PreservePosition(tl.Dense(d_model)),
            PreservePosition(tl.Dense(d_model)),
        ),
        tl.Parallel(
            CopyHeadsPos(h=n_heads),
            MixHeadsPos(h=n_heads),
            MixHeadsPos(h=n_heads),
        ),
        tl.PureAttention(d_model=d_model,
                         n_heads=n_heads,
                         dropout=dropout,
                         mode=mode),
        tl.Parallel([], tl.Drop()),  # Drop the mask.
        CombineHeadsPos(h=n_heads),
        PreservePosition(tl.Dense(d_model)),
    )
Example #18
    def Decoder(memory, target, target_mask, memory_mask):
        """Transformer decoder stack.

    Args:
      memory: layer variable: encoded source sequences
      target: layer variable: raw target sequences
      target_mask: layer variable: self-attention mask
      memory_mask: layer variable: memory attention mask

    Returns:
      Layer variable that outputs encoded source.
    """
        decoder_layer = layers.Serial(
            # target attends to self
            layers.Residual(
                layers.LayerNorm(),
                layers.Branch(size=4),
                layers.Parallel(
                    layers.Identity(),  # query
                    layers.Identity(),  # key
                    layers.Identity(),  # value
                    target_mask),  # attention mask
                multi_attention,
                layers.Dropout(dropout, mode=mode)),
            # target attends to encoded source
            layers.Residual(
                layers.LayerNorm(),
                layers.Branch(size=4),
                layers.Parallel(
                    layers.Identity(),  # query
                    memory,  # key
                    memory,  # value
                    memory_mask),  # attention mask
                multi_attention,
                layers.Dropout(dropout, mode=mode)),
            # feed-forward
            ResidualFeedForward(feature_depth,
                                feedforward_depth,
                                dropout,
                                mode=mode))
        return layers.Serial(
            target,
            target_embedding_layer,
            layers.repeat(decoder_layer, num_layers),
            layers.LayerNorm(),
        )
Example #19
 def test_transformer_lm_forward_shape(self):
   """Run the Transformer LM forward and check output shape."""
   vocab_size = 16
   input_shape = [3, 5]
   model = transformer.TransformerLM(
       vocab_size, d_feature=32, d_feedforward=64, n_layers=2, n_heads=2)
   final_shape = tl.check_shape_agreement(
       tl.Serial(model), tuple(input_shape), integer_inputs=True)
   self.assertEqual(tuple(input_shape + [vocab_size]), final_shape)
Example #20
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutational block."""
    main = tl.Serial(tl.BatchNorm(), tl.Relu(),
                     tl.Conv(channels, (3, 3), strides, padding='SAME'),
                     tl.BatchNorm(), tl.Relu(),
                     tl.Conv(channels, (3, 3), padding='SAME'))
    shortcut = tl.Copy() if not channel_mismatch else tl.Conv(
        channels, (3, 3), strides, padding='SAME')
    return tl.Residual(main, shortcut=shortcut)
Example #21
def IdentityBlock(kernel_size, filters):
  """ResNet identical size block."""
  ks = kernel_size
  filters1, filters2, filters3 = filters
  main = tl.Serial(
      tl.Conv(filters1, (1, 1)),
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(filters2, (ks, ks), padding='SAME'),
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(filters3, (1, 1)),
      tl.BatchNorm()
  )
  return tl.Serial(
      tl.Residual(main),
      tl.Relu()
  )
Example #22
def TransformerEncoder(vocab_size,
                       num_classes=10,
                       feature_depth=512,
                       feedforward_depth=2048,
                       num_layers=6,
                       num_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
  """Transformer encoder.

  Args:
    vocab_size: int: vocab size
    num_classes: how many classes on output
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer encoder layer.
  """
  input_embedding = layers.Serial(
      layers.Embedding(feature_depth, vocab_size),
      layers.Dropout(rate=dropout, mode=mode),
      layers.PositionalEncoding(max_len=max_len)
  )
  return layers.Serial(
      layers.Branch(),  # Branch input to create embedding and mask.
      layers.Parallel(input_embedding, layers.PaddingMask()),
      layers.Serial(*[EncoderLayer(feature_depth, feedforward_depth, num_heads,
                                   dropout, mode)
                      for _ in range(num_layers)]),
      layers.FirstBranch(),  # Drop the mask.
      layers.LayerNorm(),
      layers.Mean(axis=1),  # Average on length.
      layers.Dense(num_classes),
      layers.LogSoftmax()
  )
Example #23
 def __init__(self, layer, check_shapes=True):
   super(Map, self).__init__()
   if layer is None or isinstance(layer, (list, tuple)):
     layer = tl.Serial(layer)
   self._layer = layer
   # Generally a Map should be applied to lists where all elements have
   # the same shape -- because self._layer will only be initialized once
   # and it could have different parameters for different shapes. But there
   # are valid cases -- e.g., when self._layer has no parameters -- where we
   # can apply Map to different shapes -- set check_shapes=False in such cases.
   self._check_shapes = check_shapes
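
The shape-check caveat in the comment above is why the chunked attention in Example #24 below disables the check only for parameter-free layers. A minimal sketch of the distinction, reusing the tl.Map / tl.Select / tl.Dense calls that appear in that example:

# Parameter-free layer: mapped chunks may have different shapes, so skip the check.
drop_masks = tl.Map(tl.Select(0), check_shapes=False)
# Parameterized layer: it is initialized once, so all mapped chunks must share a shape.
project = tl.Map(tl.Dense(128))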
Example #24
def ChunkedCausalMultiHeadedAttention(d_feature,
                                      n_heads=8,
                                      dropout=0.0,
                                      chunk_selector=None,
                                      mode='train'):
    """Transformer-style causal multi-headed attention operating on chunks.

  Accepts inputs that are a list of chunks and applies causal attention.

  Args:
    d_feature: int:  depth of embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    chunk_selector: a function from chunk number to list of chunks to attend.
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
    prepare_attention_input = tl.Serial(
        tl.Branch(
            tl.Branch(  # q = k = v = first input
                tl.NoOp(), tl.NoOp(), tl.NoOp()),
            tl.CausalMask(axis=-2),
        ),
        tl.Parallel(
            tl.Parallel(
                tl.Dense(d_feature),
                tl.Dense(d_feature),
                tl.Dense(d_feature),
            ), tl.NoOp()))
    return tl.Serial(
        tl.Map(prepare_attention_input),
        ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
        tl.Map(tl.PureMultiHeadedAttention(d_feature=d_feature,
                                           n_heads=n_heads,
                                           dropout=dropout,
                                           mode=mode),
               check_shapes=False),
        tl.Map(tl.Select(0), check_shapes=False),  # drop masks
        tl.Map(tl.Dense(d_feature)))
Example #25
def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         num_actions,
                         bottom_layers=None):
  """A policy and value net function."""

  # Layers.
  cur_layers = []
  if bottom_layers is not None:
    cur_layers.extend(bottom_layers)

  # Now, with the current logits, one head computes action probabilities and the
  # other computes the value function.
  # NOTE: The LogSoftmax instead of the Softmax because of numerical stability.
  cur_layers.extend([
      layers.Branch(
          layers.Serial(layers.Dense(num_actions), layers.LogSoftmax()),
          layers.Dense(1))
  ])
  net = layers.Serial(*cur_layers)
  return net.initialize(batch_observations_shape, rng_key), net
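
The NOTE above prefers LogSoftmax over Softmax for numerical stability; the difference is easy to see with a small standalone check (plain NumPy, independent of trax):

import numpy as np

logits = np.array([1000.0, 999.0, 0.0])          # large logits overflow a naive softmax
naive = np.exp(logits) / np.sum(np.exp(logits))  # exp(1000) -> inf, so inf / inf -> nan
print(naive)                                     # [nan nan  0.]

# Log-softmax with the max subtracted (log-sum-exp trick) stays finite.
shifted = logits - np.max(logits)
log_probs = shifted - np.log(np.sum(np.exp(shifted)))
print(log_probs)                                 # approx [-0.31, -1.31, -1000.31]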
Example #26
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=tl.Relu,
        num_output_classes=10,
        mode="train"):
    """Multi-layer feed-forward neural network with non-linear activations."""
    del mode
    cur_layers = [tl.Flatten()]
    for _ in range(num_hidden_layers):
        cur_layers += [tl.Dense(hidden_size), activation_fn()]
    cur_layers += [tl.Dense(num_output_classes), tl.LogSoftmax()]
    return tl.Serial(*cur_layers)
Example #27
  def __init__(self, residual_layers):
    self.compute_residual = tl.Serial([
        # TODO(jonni): Rewrite without using Select.
        tl.Select(inputs=('x1_or_y1', 'x2'), output=('x2', 'x1_or_y1', 'x2')),
        tl.Parallel(residual_layers, [], []),
    ])

    layers = [self.compute_residual, tl.Add()]
    super(ReversibleHalfResidual, self).__init__(layers)

    self.subtract_top = tl.SubtractTop()
    self.reverse_layers = [self.compute_residual, self.subtract_top]
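
The forward/reverse pairing above (Add on the way forward, SubtractTop on the way back) relies on the residual connection being exactly invertible. A small standalone check of that identity, with a plain function standing in for the residual branch (NumPy only, independent of the layer machinery):

import numpy as np

def residual_branch(x2):                    # stand-in for compute_residual's inner layers
    return 3.0 * x2 + 1.0

x1, x2 = np.array([1.0, 2.0]), np.array([0.5, -0.5])
y1 = x1 + residual_branch(x2)               # forward: compute_residual, then Add
x1_recovered = y1 - residual_branch(x2)     # reverse: recompute the residual, SubtractTop
assert np.allclose(x1, x1_recovered)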
Example #28
    def __init__(self, residual_layers):
        self.compute_residual = tl.Serial([
            # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
            tl.Parallel([], tl.Dup()),
            tl.Swap(),
            tl.Parallel(residual_layers, [], []),
        ])

        layers = [self.compute_residual, tl.Parallel(tl.Add(), [])]
        super(ReversibleHalfResidual, self).__init__(layers)

        self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
        self.reverse_layers = [self.compute_residual, self.subtract_top]
Example #29
def ConvBlock(kernel_size, filters, strides):
  """ResNet convolutional striding block."""
  ks = kernel_size
  filters1, filters2, filters3 = filters
  main = tl.Serial(
      tl.Conv(filters1, (1, 1), strides),
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(filters2, (ks, ks), padding='SAME'),
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(filters3, (1, 1)),
      tl.BatchNorm()
  )
  shortcut = tl.Serial(
      tl.Conv(filters3, (1, 1), strides),
      tl.BatchNorm()
  )
  return tl.Serial(
      tl.Residual(main, shortcut=shortcut),
      tl.Relu()
  )
Example #30
def EncoderLayer(feature_depth,
                 feedforward_depth,
                 num_heads,
                 dropout,
                 mode):
  """Transformer encoder layer.

  The input to the encoder is a pair (embedded source, mask) where
  the mask is created from the original source to prevent attending
  to the padding part of the input.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (activations, mask).
  """
  # The encoder block expects (activation, mask) as input and returns
  # the new activations only; we add the mask back to the output next.
  encoder_block = layers.Serial(
      layers.Residual(  # Attention block here.
          layers.Parallel(layers.LayerNorm(), layers.Identity()),
          layers.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                      dropout=dropout, mode=mode),
          layers.Dropout(rate=dropout, mode=mode),
          shortcut=layers.FirstBranch()
      ),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)
  )
  # Now we add the mask back.
  return layers.Serial(
      layers.Reorder(output=((0, 1), 1)),  # (x, mask) --> ((x, mask), mask)
      layers.Parallel(encoder_block, layers.Identity())
  )