Example #1
 def test_chunk_uneven_numbers(self):
     layer = tl.Dense(4)
     x = np.array([[1, 2, 3], [4, 5, 6]])
     layer.init(x)
     y = layer(x)
     z = tl.Chunk(layer, 3)(x)  # Not divisible by 3: passes through by default.
     self.assertLess(np.sum((y - z)**2), 1e-5)  # y == z up to numerics
     chunk_with_test = tl.Chunk(layer, 3, pass_unchunkable=False)
     self.assertRaises(tl.LayerError, lambda: chunk_with_test(x))
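A minimal sketch of the behavior these tests rely on (assuming a standard trax install; initialization via shapes.signature as in Example #5 below): when the leading axis is divisible by the chunk size, tl.Chunk runs the wrapped layer chunk by chunk and concatenates the results, matching the unchunked output up to numerics.

import numpy as np
from trax import layers as tl, shapes

layer = tl.Dense(4)
x = np.ones((6, 3), dtype=np.float32)  # 6 rows split evenly into chunks of 3
layer.init(shapes.signature(x))
y_full = layer(x)                  # whole batch at once
y_chunked = tl.Chunk(layer, 3)(x)  # same computation, run in chunks of 3 rows
assert np.sum((y_full - y_chunked) ** 2) < 1e-5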
Example #2
 def test_chunk(self):
     layer = tl.Dense(4)
     x = np.array([[1, 2, 3], [4, 5, 6]])
     layer.init(x)
     y = layer(x)
     z = tl.Chunk(layer, 1)(x)
     self.assertLess(np.sum((y - z)**2), 1e-5)  # y == z up to numerics
Example #3
    def test_chunk_grad_memory(self):
        """Test chunking gradient here to exercise accelerator memory usage."""
        layer = tl.Serial(tl.Dense(1024 * 1024), tl.Dense(128))
        chunked = tl.Chunk(layer, 256)

        @fastmath.jit
        def mock_training_step(x, weights, state, rng):
            def compute_mock_loss(weights):
                logits, new_state = chunked.pure_fn(x, weights, state, rng)
                loss = fastmath.numpy.mean(logits)
                return loss, (new_state, logits)

            gradients, (new_state,
                        logits) = fastmath.grad(compute_mock_loss,
                                                has_aux=True)(weights)
            new_weights = fastmath.nested_map_multiarg(
                lambda w, g: w - 1e-4 * g, weights, gradients)
            return new_weights, new_state, logits

        x = np.random.uniform(size=(16 * 1024, 16))
        chunked.init(shapes.signature(x))
        weights, _, logits = mock_training_step(x, chunked.weights,
                                                chunked.state,
                                                fastmath.random.get_prng(0))
        self.assertEqual(logits.shape, (16 * 1024, 128))
        self.assertEqual(weights[1][0][0][0].shape, (16, 1024 * 1024))
Example #4
def ChunkedFeedForward(d_model, d_ff, dropout, activation, act_dropout,
                       chunk_size, mode):
    """Chunked feed-forward block with layer normalization at start."""
    ff = FeedForward(d_model, d_ff, dropout, activation, act_dropout, mode)
    if chunk_size < 1:
        return ff
    return tl.BatchLeadingAxes(tl.Chunk(tl.Serial(ff), chunk_size))
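A hedged sketch of how the wrapper above composes, using tl.Dense as a stand-in for the FeedForward block (which is defined elsewhere in this module): tl.BatchLeadingAxes merges all leading axes into a single batch axis so that tl.Chunk can split it into fixed-size chunks, and restores the original shape afterwards.

import numpy as np
from trax import layers as tl, shapes

ff = tl.Dense(8)  # stand-in for FeedForward(...)
chunked = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(ff), 4))
x = np.ones((2, 16, 8), dtype=np.float32)  # (batch, length, d_model)
chunked.init(shapes.signature(x))
y = chunked(x)  # 2 * 16 = 32 merged rows -> 8 chunks of 4
assert y.shape == (2, 16, 8)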
Example #5
 def test_chunk_memory(self):
     """Test chunking here to exercise accelerator memory usage."""
     layer = tl.Serial(tl.Dense(1024 * 1024), tl.Dense(128))
     chunked = tl.Chunk(layer, 256)
     x = np.random.uniform(size=(16 * 1024, 16))
     chunked.init(shapes.signature(x))
     y = chunked(x)
     z = tl.Accelerate(chunked)(x)
     self.assertEqual(y.shape, (16 * 1024, 128))
     self.assertEqual(z.shape, (16 * 1024, 128))
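Rough arithmetic (float32 assumed) for why this test exercises accelerator memory: the activation after Dense(1024 * 1024) dominates, and chunking bounds it by the chunk size instead of the full batch.

act_unchunked = 16 * 1024 * (1024 * 1024) * 4 / 2**30  # ~64 GiB if done at once
act_per_chunk = 256 * (1024 * 1024) * 4 / 2**30        # ~1 GiB per chunk of 256
print(act_unchunked, act_per_chunk)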
Example #6
def ApplyAttentionLayer(attention_type, d_model, n_heads, d_qk, d_v, causal,
                        masked, attention_dropout, output_dropout,
                        attention_chunk_size, mode):
  """Runs the supplied attention layer."""
  try:
    attention = attention_type(
        n_heads=n_heads,
        d_qk=d_qk,
        d_v=d_v,
        causal=causal,
        masked=masked,
        output_dropout=output_dropout,
        attention_dropout=attention_dropout,
        mode=mode)
  except TypeError:  # No d_qk arguments in less advanced layers.
    attention = attention_type(
        d_model, n_heads=n_heads, dropout=attention_dropout, mode=mode)
  return tl.Chunk(attention, attention_chunk_size)
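A hedged usage sketch (parameter values are illustrative; tl.Attention is an assumption, not taken from this file): tl.Attention accepts no d_qk/d_v arguments, so the first construction raises TypeError and the fallback branch builds it with the simpler signature; the result is then passed to tl.Chunk, where a chunk size of 0 is expected to leave the layer unchunked.

from trax import layers as tl

attn = ApplyAttentionLayer(
    tl.Attention, d_model=64, n_heads=2, d_qk=32, d_v=32,
    causal=False, masked=True, attention_dropout=0.1, output_dropout=0.1,
    attention_chunk_size=0, mode='train')  # chunk size 0: no chunking assumed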
Example #7
def FeedForwardWithOptions(d_model,
                           d_ff,
                           dropout,
                           dropout_shared_axes,
                           ff_activation,
                           ff_dropout,
                           ff_chunk_size,
                           ff_use_sru,
                           ff_sparsity,
                           mode,
                           use_bfloat16=False,
                           ff_sparsity_type='1inN'):
    """Feed-Forward block with all the options.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    dropout: Stochastic rate (probability) for dropping an activation value when
      applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    ff_activation: Type of activation function at the end of each block; must be
      an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk the feed-forward computation into chunks
      of this size.
    ff_use_sru: int or pair of ints; if > 0, use this many SRU layers in
      addition to the feed-forward block (the second int specifies the SRU
      size).
    ff_sparsity: int, tuple or string; if not 0, use a sparse feed-forward
      block with this sparsity.
    mode: If `'train'`, each block will include dropout; else, it will pass all
      values through unaltered.
    use_bfloat16: whether to use bfloat16 for weights (default: False).
    ff_sparsity_type: string; if ff_sparsity > 0, use SparseFF when
      ff_sparsity_type=`'1inN'`, BlockSparseFF when ff_sparsity_type=`'Block'`,
      and SwitchSparseFF when ff_sparsity_type=`'Switch'`.

  Returns:
    A list of layers which maps vectors to vectors.
  """
    if ff_sparsity and ff_sparsity_type == '1inN':
        temperature, quant_prob = 0.1, 0.3
        if isinstance(ff_sparsity, str):
            # This is hacky but used to pass ff_sparsity in yaml sweep files.
            ff_sparsity = [(float(x) if '.' in x else int(x))
                           for x in ff_sparsity.split()]
        if isinstance(ff_sparsity, (list, tuple)):
            if len(ff_sparsity) == 2:
                n_elements_in_block, d_lowrank = ff_sparsity
            else:
                n_elements_in_block, d_lowrank, temperature, quant_prob = ff_sparsity
        else:
            assert isinstance(ff_sparsity, int)
            n_elements_in_block, d_lowrank = ff_sparsity, d_ff // ff_sparsity
        ff = tl.SparseFF(d_ff,
                         n_elements_in_block=n_elements_in_block,
                         d_lowrank=d_lowrank,
                         temperature=temperature,
                         quant_prob=quant_prob,
                         use_bfloat16=use_bfloat16,
                         mode=mode,
                         dropout_rate=dropout,
                         dropout_shared_axes=dropout_shared_axes,
                         ff_chunk_size=ff_chunk_size)
    elif ff_sparsity and ff_sparsity_type == 'Block':
        ff = tl.BlockSparseFF(d_ff, n_experts=ff_sparsity, mode=mode)
    elif ff_sparsity and ff_sparsity_type == 'Switch':
        ff = tl.SwitchSparseFF(d_ff, n_experts=ff_sparsity, mode=mode)
    else:
        ff = _FeedForward(d_model, d_ff, dropout, ff_activation, ff_dropout,
                          use_bfloat16, mode)
    res = [tl.LayerNorm(), ff]
    if ff_sparsity_type != '1inN' or ff_sparsity == 0:
        # SparseFF has Dropout and BatchLeadingAxes built-in.
        res.append(
            tl.Dropout(rate=dropout,
                       shared_axes=dropout_shared_axes,
                       mode=mode))
        if ff_chunk_size > 0:
            res = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(res), ff_chunk_size))
    if ff_use_sru:
        if isinstance(ff_use_sru, (list, tuple)):
            sru_n_layers, sru_n_units = ff_use_sru
        else:
            sru_n_layers, sru_n_units = ff_use_sru, 32
        sru = [tl.SRU(sru_n_units, mode=mode) for _ in range(sru_n_layers)]
        block = [tl.LayerNorm(), tl.Dense(sru_n_units)
                 ] + sru + [tl.Dense(d_model)]
        res = tl.Residual(block, shortcut=res)
    return [res]
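A hedged usage sketch (values are illustrative; tl.Relu and tl.Residual are assumed from trax.layers): with ff_sparsity=0 and ff_chunk_size > 0 the code above returns the plain dense feed-forward wrapped in BatchLeadingAxes + Chunk, and the returned list is typically placed inside a residual connection.

from trax import layers as tl

ff_block = FeedForwardWithOptions(
    d_model=512, d_ff=2048, dropout=0.1, dropout_shared_axes=None,
    ff_activation=tl.Relu, ff_dropout=0.1, ff_chunk_size=64,
    ff_use_sru=0, ff_sparsity=0, mode='train')
residual_ff = tl.Residual(*ff_block)  # ff_block is a one-element list of layers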
Example #8
def extract_reversible_blocks(layers, loss_chunk_size=0):
  """Extracts blocks and loss layer for use with ReversibleSerialTrainer.

  Args:
    layers: a list of layers of a single layer to extract blocks from;
      should end with a loss, e.g., [model, loss] or tl.Serial(model, loss).
    loss_chunk_size: int, if > 0 creates a chunked loss layer to save memory
      in models with larger vocabulary; requires the last sublayers of loss
      are [Dense, LogSoftmax, _CrossEntropy, _WeightedMean] in that order.

  Returns:
    a pair (blocks, loss_layer) to use with ReversibleSerialTrainer.
  """
  def _flatten(l):
    """Flatten all Serial layers and sub(sub-...) layers into a list."""
    if isinstance(l, (list, tuple)):
      return [x for layer in l for x in _flatten(layer)]  # pylint: disable=g-complex-comprehension
    elif isinstance(l, tl.Serial):
      return _flatten(l.sublayers)
    else:
      return [l]

  # Extract standard and reversible layer blocks.
  blocks, std_layers, rev_layers = [], [], []
  for layer in _flatten(layers):
    if isinstance(layer, tl.ReversibleLayer):
      rev_layers.append(layer)
    elif not rev_layers:
      std_layers.append(layer)
    else:
      blocks.append((std_layers, rev_layers))
      std_layers, rev_layers = [], []
      std_layers.append(layer)
  if rev_layers:
    raise ValueError('The final layer must be a standard loss, not reversible.')
  if loss_chunk_size > 0:
    # For now we only chunk [Dense, LogSoftmax, _CrossEntropy, _WeightedMean].
    # Let's check that these are the last 4 layers.
    if len(std_layers) < 4:
      raise ValueError('Too short loss layer for chunking')
    # To check for Dense, remove the n_units part from name.
    name4 = std_layers[-4].name[:5]  # Just 'Dense' not e.g., 'Dense_32000'.
    last_4_names = ' '.join([name4] + [l.name for l in std_layers[-3:]])
    if last_4_names != 'Dense LogSoftmax _CrossEntropy _WeightedMean':
      raise ValueError('Loss chunking only works when the last layers are '
                       '"Dense LogSoftmax _CrossEntropy _WeightedMean" but '
                       'got: ' + last_4_names)
    # Create chunked dense+logsoftmax+cross-entropy-loss.
    chunked_xent = tl.Chunk(tl.Serial(std_layers[-4:-1]), loss_chunk_size)
    # The chunked loss should operate on a merged batch dimension, e.g.,
    # including both length and batch size. Need to merge and un-merge later.
    def _reshape_to_batch_and_copy_targets(preds, targets):
      batched_preds = jnp.reshape(preds, [-1, preds.shape[-1]])
      batched_targets = jnp.reshape(targets, [-1])
      return batched_preds, batched_targets, targets
    def _reshape_xent_back(xent, targets):
      return jnp.reshape(xent, targets.shape)
    batched_xent = tl.Serial(
        tl.Fn('pre_xent_rebatch', _reshape_to_batch_and_copy_targets, n_out=3),
        chunked_xent,
        tl.Fn('after_xent_rebatch', _reshape_xent_back)
    )
    loss_layer = tl.Serial(std_layers[:-4] + [batched_xent], std_layers[-1])
  else:
    loss_layer = tl.Serial(std_layers)
  return blocks, loss_layer
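A hedged usage sketch (tl.Embedding and tl.WeightedCategoryCrossEntropy are assumptions; the sketch relies on that loss flattening to the [LogSoftmax, _CrossEntropy, _WeightedMean] sublayers the name check above expects): a model ending with a Dense vocabulary projection followed by such a loss can have its loss chunked.

from trax import layers as tl

model = tl.Serial(tl.Embedding(vocab_size=32000, d_feature=512),
                  tl.Dense(32000))  # large vocabulary projection
blocks, loss_layer = extract_reversible_blocks(
    [model, tl.WeightedCategoryCrossEntropy()], loss_chunk_size=128)
# No reversible layers in this toy model, so `blocks` is empty; `loss_layer`
# runs the Dense + LogSoftmax + cross-entropy part in chunks of 128 rows.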
Example #9
def FeedForwardWithOptions(d_model,
                           d_ff,
                           dropout,
                           dropout_shared_axes,
                           ff_activation,
                           ff_dropout,
                           ff_chunk_size,
                           ff_use_sru,
                           ff_sparsity,
                           mode,
                           ff_sparsity_type='1inN'):
    """Feed-Forward block with all the options.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    dropout: Stochastic rate (probability) for dropping an activation value when
      applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    ff_activation: Type of activation function at the end of each block; must be
      an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk the feed-forward computation into chunks
      of this size.
    ff_use_sru: int; if > 0, use this many SRU layers instead of the
      feed-forward block.
    ff_sparsity: int; if > 0, use a sparse feed-forward block with this
      sparsity.
    mode: If `'train'`, each block will include dropout; else, it will pass all
      values through unaltered.
    ff_sparsity_type: string; if ff_sparsity > 0, use SparseFF when
      ff_sparsity_type=`'1inN'` and BlockSparseFF when
      ff_sparsity_type=`'Block'`.

  Returns:
    A list of layers which maps vectors to vectors.
  """
    if ff_use_sru:
        return [tl.SRU(d_model) for _ in range(ff_use_sru)]
    elif ff_sparsity and ff_sparsity_type == '1inN':
        ff = tl.SparseFF(d_ff,
                         n_elements_in_block=ff_sparsity,
                         d_lowrank=d_ff // ff_sparsity,
                         mode=mode)
        if ff_chunk_size < 1:
            chunked_ff = ff
        else:
            chunked_ff = tl.BatchLeadingAxes(
                tl.Chunk(tl.Serial(ff), ff_chunk_size))
        return [
            tl.LayerNorm(), chunked_ff,
            tl.Dropout(rate=dropout,
                       shared_axes=dropout_shared_axes,
                       mode=mode)
        ]
    elif ff_sparsity and ff_sparsity_type == 'Block':
        return [
            tl.LayerNorm(),
            tl.BlockSparseFF(d_ff, num_experts=ff_sparsity, mode=mode),
            tl.Dropout(rate=dropout,
                       shared_axes=dropout_shared_axes,
                       mode=mode)
        ]
    else:
        return [
            ChunkedFeedForward(d_model, d_ff, dropout, ff_activation,
                               ff_dropout, ff_chunk_size, mode)
        ]