def test_chunk_uneven_numbers(self):
  """Checks Chunk on a batch that does not split evenly into chunks."""
  dense = tl.Dense(4)
  inputs = np.array([[1, 2, 3], [4, 5, 6]])
  dense.init(inputs)
  expected = dense(inputs)

  # Two rows cannot be split into chunks of 3; by default Chunk should just
  # pass the batch through, matching the unchunked layer.
  passthrough = tl.Chunk(dense, 3)(inputs)
  squared_error = np.sum((expected - passthrough)**2)
  self.assertLess(squared_error, 1e-5)  # Equal up to numerics.

  # With pass_unchunkable=False the same uneven batch must raise instead.
  strict = tl.Chunk(dense, 3, pass_unchunkable=False)
  self.assertRaises(tl.LayerError, lambda: strict(inputs))
def test_chunk(self):
  """Checks that Chunk with chunk size 1 matches the unchunked layer."""
  dense = tl.Dense(4)
  inputs = np.array([[1, 2, 3], [4, 5, 6]])
  dense.init(inputs)
  reference = dense(inputs)
  chunked_output = tl.Chunk(dense, 1)(inputs)
  squared_error = np.sum((reference - chunked_output)**2)
  self.assertLess(squared_error, 1e-5)  # Equal up to numerics.
def test_chunk_grad_memory(self):
  """Test chunking gradient here to exercise accelerator memory usage.

  Runs one jitted, gradient-descent-like update step through a chunked
  two-layer model whose first Dense layer has 1024 * 1024 units, then checks
  the output shape and the shape of the updated first-layer weights.
  """
  layer = tl.Serial(tl.Dense(1024 * 1024), tl.Dense(128))
  chunked = tl.Chunk(layer, 256)

  @fastmath.jit
  def mock_training_step(x, weights, state, rng):
    # The mock loss is the mean of the chunked model's outputs; the new state
    # and the outputs are carried along as auxiliary data.
    def compute_mock_loss(weights):
      logits, new_state = chunked.pure_fn(x, weights, state, rng)
      loss = fastmath.numpy.mean(logits)
      return loss, (new_state, logits)
    gradients, (new_state, logits) = fastmath.grad(
        compute_mock_loss, has_aux=True)(weights)
    # Plain gradient step with learning rate 1e-4, applied leaf by leaf.
    new_weights = fastmath.nested_map_multiarg(
        lambda w, g: w - 1e-4 * g, weights, gradients)
    return new_weights, new_state, logits

  x = np.random.uniform(size=(16 * 1024, 16))
  chunked.init(shapes.signature(x))
  weights, _, logits = mock_training_step(
      x, chunked.weights, chunked.state, fastmath.random.get_prng(0))
  self.assertEqual(logits.shape, (16 * 1024, 128))
  # NOTE(review): this index path assumes the chunked layer's weight tree
  # holds the wrapped Serial's weights at [1][0], so [0][0] is the first
  # Dense's kernel (16 inputs x 1024 * 1024 units); confirm against tl.Chunk.
  self.assertEqual(weights[1][0][0][0].shape, (16, 1024 * 1024))
def ChunkedFeedForward(d_model, d_ff, dropout, activation, act_dropout,
                       chunk_size, mode):
  """Feed-forward block, optionally evaluated in chunks to save memory.

  The original summary describes this as a feed-forward block with layer
  normalization at the start; that normalization is presumably part of
  `FeedForward`, which is not visible here (confirm).

  Args:
    d_model: Forwarded to `FeedForward`.
    d_ff: Forwarded to `FeedForward`.
    dropout: Forwarded to `FeedForward`.
    activation: Forwarded to `FeedForward`.
    act_dropout: Forwarded to `FeedForward`.
    chunk_size: If less than 1, the feed-forward layer is returned as is;
      otherwise it is wrapped as `tl.BatchLeadingAxes(tl.Chunk(..., chunk_size))`.
    mode: Forwarded to `FeedForward`.

  Returns:
    The feed-forward layer, possibly wrapped for chunked evaluation.
  """
  ff = FeedForward(d_model, d_ff, dropout, activation, act_dropout, mode)
  return (ff if chunk_size < 1 else
          tl.BatchLeadingAxes(tl.Chunk(tl.Serial(ff), chunk_size)))
def test_chunk_memory(self):
  """Test chunking here to exercise accelerator memory usage."""
  big_model = tl.Serial(tl.Dense(1024 * 1024), tl.Dense(128))
  chunked = tl.Chunk(big_model, 256)
  batch = np.random.uniform(size=(16 * 1024, 16))
  chunked.init(shapes.signature(batch))
  eager_out = chunked(batch)
  accelerated_out = tl.Accelerate(chunked)(batch)
  # Both the direct call and the accelerated call must keep the batch size.
  expected_shape = (16 * 1024, 128)
  for out in (eager_out, accelerated_out):
    self.assertEqual(out.shape, expected_shape)
def ApplyAttentionLayer(attention_type, d_model, n_heads, d_qk, d_v, causal,
                        masked, attention_dropout, output_dropout,
                        attention_chunk_size, mode):
  """Builds the supplied attention layer and wraps it in `tl.Chunk`.

  First tries the full keyword interface (`n_heads`, `d_qk`, `d_v`, `causal`,
  `masked`, `output_dropout`, `attention_dropout`, `mode`). If that raises
  `TypeError` (simpler attention layers have no `d_qk` argument), it falls
  back to `attention_type(d_model, n_heads=..., dropout=attention_dropout,
  mode=...)`. The fallback ignores `d_qk`, `d_v`, `causal`, `masked` and
  `output_dropout`.

  Returns:
    `tl.Chunk(attention, attention_chunk_size)`.
  """
  advanced_kwargs = dict(
      n_heads=n_heads,
      d_qk=d_qk,
      d_v=d_v,
      causal=causal,
      masked=masked,
      output_dropout=output_dropout,
      attention_dropout=attention_dropout,
      mode=mode)
  try:
    attention = attention_type(**advanced_kwargs)
  except TypeError:
    # NOTE(review): a TypeError raised inside the constructor itself also
    # lands here and silently selects the simpler interface.
    attention = attention_type(
        d_model, n_heads=n_heads, dropout=attention_dropout, mode=mode)
  return tl.Chunk(attention, attention_chunk_size)
def FeedForwardWithOptions(d_model,
                           d_ff,
                           dropout,
                           dropout_shared_axes,
                           ff_activation,
                           ff_dropout,
                           ff_chunk_size,
                           ff_use_sru,
                           ff_sparsity,
                           mode,
                           use_bfloat16=False,
                           ff_sparsity_type='1inN'):
  """Feed-Forward block with all the options.

  The core feed-forward layer is chosen as follows:
    * `ff_sparsity` truthy and `ff_sparsity_type == '1inN'`: `tl.SparseFF`,
      which receives the dropout settings and `ff_chunk_size` itself.
    * `ff_sparsity` truthy and `ff_sparsity_type == 'Block'`:
      `tl.BlockSparseFF` with `n_experts=ff_sparsity`.
    * `ff_sparsity` truthy and `ff_sparsity_type == 'Switch'`:
      `tl.SwitchSparseFF` with `n_experts=ff_sparsity`.
    * otherwise: the dense `_FeedForward` block.
  The layers are `[LayerNorm, core]`, followed by a Dropout layer whenever
  SparseFF is not used. The result is chunked when `ff_chunk_size > 0` and
  combined with an SRU stack via `tl.Residual` when `ff_use_sru` is set.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
    ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers in
        addition to the feed-forward block (second int specifies sru size,
        default 32).
    ff_sparsity: int, tuple or string; if truthy, use a sparse feed-forward
        block with this sparsity. For the '1inN' type an int n means
        n_elements_in_block=n and d_lowrank=d_ff // n; a list/tuple is
        (n_elements_in_block, d_lowrank) or (n_elements_in_block, d_lowrank,
        temperature, quant_prob); a string is a whitespace-separated list of
        such numbers (entries containing '.' are parsed as floats).
    mode: If `'train'`, each block will include dropout; else, it will pass
        all values through unaltered.
    use_bfloat16: whether to use bfloat16 for weights (default: False).
    ff_sparsity_type: string, if ff_sparsity is truthy, use SparseFF if
        ff_sparsity_type=`'1inN'`, BlockSparseFF if ff_sparsity_type=`'Block'`
        and SwitchSparseFF if ff_sparsity_type=`'Switch'`.

  Returns:
    A one-element list. Its element is the plain list of layers when neither
    chunking nor SRU applies, otherwise the wrapping combinator layer.

  Raises:
    ValueError: if a '1inN' `ff_sparsity` sequence has neither 2 nor 4
        entries.
  """
  # SparseFF has its own Dropout built in, so the trailing Dropout below is
  # skipped only when SparseFF is actually used. This uses the same
  # truthiness test as the branch conditions, so a dense block built for a
  # falsy `ff_sparsity` such as None still gets its Dropout.
  use_sparse_ff = bool(ff_sparsity) and ff_sparsity_type == '1inN'

  if use_sparse_ff:
    temperature, quant_prob = 0.1, 0.3
    if isinstance(ff_sparsity, str):
      # This is hacky but used to pass ff_sparsity in yaml sweep files.
      ff_sparsity = [(float(x) if '.' in x else int(x))
                     for x in ff_sparsity.split()]
    if isinstance(ff_sparsity, (list, tuple)):
      if len(ff_sparsity) == 2:
        n_elements_in_block, d_lowrank = ff_sparsity
      elif len(ff_sparsity) == 4:
        n_elements_in_block, d_lowrank, temperature, quant_prob = ff_sparsity
      else:
        raise ValueError(
            'ff_sparsity for the 1inN type must have 2 or 4 entries, got: %r'
            % (ff_sparsity,))
    else:
      assert isinstance(ff_sparsity, int)
      n_elements_in_block, d_lowrank = ff_sparsity, d_ff // ff_sparsity
    ff = tl.SparseFF(
        d_ff,
        n_elements_in_block=n_elements_in_block,
        d_lowrank=d_lowrank,
        temperature=temperature,
        quant_prob=quant_prob,
        use_bfloat16=use_bfloat16,
        mode=mode,
        dropout_rate=dropout,
        dropout_shared_axes=dropout_shared_axes,
        ff_chunk_size=ff_chunk_size)
  elif ff_sparsity and ff_sparsity_type == 'Block':
    ff = tl.BlockSparseFF(d_ff, n_experts=ff_sparsity, mode=mode)
  elif ff_sparsity and ff_sparsity_type == 'Switch':
    ff = tl.SwitchSparseFF(d_ff, n_experts=ff_sparsity, mode=mode)
  else:
    ff = _FeedForward(d_model, d_ff, dropout, ff_activation, ff_dropout,
                      use_bfloat16, mode)

  res = [tl.LayerNorm(), ff]
  if not use_sparse_ff:
    res.append(
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode))
  if ff_chunk_size > 0:
    # NOTE(review): SparseFF was already given ff_chunk_size above, so in that
    # case the block is chunked here as well; confirm this is intended.
    res = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(res), ff_chunk_size))
  if ff_use_sru:
    if isinstance(ff_use_sru, (list, tuple)):
      sru_n_layers, sru_n_units = ff_use_sru
    else:
      sru_n_layers, sru_n_units = ff_use_sru, 32
    sru = [tl.SRU(sru_n_units, mode=mode) for _ in range(sru_n_layers)]
    # The SRU stack forms the residual branch; the feed-forward block built
    # above is passed as the shortcut.
    block = [tl.LayerNorm(), tl.Dense(sru_n_units)] + sru + [tl.Dense(d_model)]
    res = tl.Residual(block, shortcut=res)
  return [res]
def extract_reversible_blocks(layers, loss_chunk_size=0):
  """Extracts blocks and loss layer for use with ReversibleSerialTrainer.

  The input is flattened (nested lists/tuples and `tl.Serial` layers are
  expanded recursively) and scanned left to right. Each run of (possibly
  zero) standard layers followed by one or more `tl.ReversibleLayer`s becomes
  one `(std_layers, rev_layers)` block. The standard layers left after the
  last reversible run form the loss layer.

  Args:
    layers: a list of layers or a single layer to extract blocks from;
      should end with a loss, e.g., [model, loss] or tl.Serial(model, loss).
    loss_chunk_size: int, if > 0 creates a chunked loss layer to save memory
      in models with larger vocabulary; requires the last sublayers of loss
      are [Dense, LogSoftmax, _CrossEntropy, _WeightedMean] in that order.

  Returns:
    a pair (blocks, loss_layer) to use with ReversibleSerialTrainer.

  Raises:
    ValueError: if the flattened layers end with a reversible layer, or if
      `loss_chunk_size > 0` and the loss does not end with the four layers
      listed above.
  """
  def _flatten(l):
    """Flatten all Serial layers and sub(sub-...) layers into a list."""
    if isinstance(l, (list, tuple)):
      return [x for layer in l for x in _flatten(layer)]  # pylint: disable=g-complex-comprehension
    elif isinstance(l, tl.Serial):
      return _flatten(l.sublayers)
    else:
      return [l]

  # Extract standard and reversible layer blocks.
  blocks, std_layers, rev_layers = [], [], []
  for layer in _flatten(layers):
    if isinstance(layer, tl.ReversibleLayer):
      rev_layers.append(layer)
    elif not rev_layers:
      std_layers.append(layer)
    else:
      # A standard layer after reversible ones closes the current block and
      # starts the standard part of the next one.
      blocks.append((std_layers, rev_layers))
      std_layers, rev_layers = [], []
      std_layers.append(layer)
  if rev_layers:
    raise ValueError('The final layer must be a standard loss, not reversible.')
  # From here on, std_layers holds the trailing standard layers (the loss).
  if loss_chunk_size > 0:
    # For now we only do chunking of [Dense, LogSoftmax, CrossEntropy, Mean]
    # Let's check that these are the last 4 layers.
    if len(std_layers) < 4:
      raise ValueError('Too short loss layer for chunking')
    # To check for Dense, remove the n_units part from name.
    # NOTE(review): the 5-character prefix accepts any name starting with
    # "Dense", not only "Dense_<n_units>".
    name4 = std_layers[-4].name[:5]  # Just 'Dense' not e.g., 'Dense_32000'.
    last_4_names = ' '.join([name4] + [l.name for l in std_layers[-3:]])
    if last_4_names != 'Dense LogSoftmax _CrossEntropy _WeightedMean':
      raise ValueError('Loss chunking only works with last layers being "Dense'
                       ' LogSoftmax, _CrossEntropy, _WeightedMean" but got: ' +
                       last_4_names)

    # Create chunked dense+logsoftmax+cross-entropy-loss.
    chunked_xent = tl.Chunk(tl.Serial(std_layers[-4:-1]), loss_chunk_size)
    # The chunked loss should operate on a merged batch dimension, e.g.,
    # including both length and batch size. Need to merge and un-merge later.
    def _reshape_to_batch_and_copy_targets(preds, targets):
      # Flattens all leading axes of preds and targets into one batch axis.
      # The original targets are returned a second time, presumably so they
      # stay on the stack for _reshape_xent_back to restore the shape.
      batched_preds = jnp.reshape(preds, [-1, preds.shape[-1]])
      batched_targets = jnp.reshape(targets, [-1])
      return batched_preds, batched_targets, targets
    def _reshape_xent_back(xent, targets):
      # Restores the per-position cross-entropy to the targets' shape.
      return jnp.reshape(xent, targets.shape)
    batched_xent = tl.Serial(
        tl.Fn('pre_xent_rebatch', _reshape_to_batch_and_copy_targets, n_out=3),
        chunked_xent,
        tl.Fn('after_xent_rebatch', _reshape_xent_back)
    )
    # Replace the Dense/LogSoftmax/_CrossEntropy triple with its batched,
    # chunked version; keep the preceding layers and the final _WeightedMean.
    loss_layer = tl.Serial(std_layers[:-4] + [batched_xent], std_layers[-1])
  else:
    loss_layer = tl.Serial(std_layers)
  return blocks, loss_layer
def FeedForwardWithOptions(d_model, d_ff, dropout, dropout_shared_axes,
                           ff_activation, ff_dropout, ff_chunk_size,
                           ff_use_sru, ff_sparsity, mode,
                           ff_sparsity_type='1inN'):
  """Builds the feed-forward part of a block as selected by the `ff_*` options.

  Exactly one variant is returned, checked in this order:
    * `ff_use_sru` truthy: `ff_use_sru` SRU layers of width `d_model`
      (no LayerNorm or Dropout is added).
    * `ff_sparsity` truthy and `ff_sparsity_type == '1inN'`: LayerNorm, a
      SparseFF (chunked unless `ff_chunk_size < 1`), then Dropout.
    * `ff_sparsity` truthy and `ff_sparsity_type == 'Block'`: LayerNorm, a
      BlockSparseFF, then Dropout (not chunked).
    * otherwise: a single `ChunkedFeedForward` layer.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity.
    mode: If `'train'`, each block will include dropout; else, it will pass
        all values through unaltered.
    ff_sparsity_type: string, if ff_sparsity > 0, use SparseFF if
        ff_sparsity_type=`'1inN'` and use BlockSparseFF if
        ff_sparsity_type=`'Block'`.

  Returns:
    A list of layers which maps vectors to vectors.
  """
  if ff_use_sru:
    return [tl.SRU(d_model) for _ in range(ff_use_sru)]

  # Sparse variants are only considered for a truthy sparsity setting.
  sparse_kind = ff_sparsity_type if ff_sparsity else None

  if sparse_kind == '1inN':
    sparse_ff = tl.SparseFF(d_ff, n_elements_in_block=ff_sparsity,
                            d_lowrank=d_ff // ff_sparsity, mode=mode)
    core = (sparse_ff if ff_chunk_size < 1 else
            tl.BatchLeadingAxes(tl.Chunk(tl.Serial(sparse_ff), ff_chunk_size)))
    return [
        tl.LayerNorm(),
        core,
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
    ]

  if sparse_kind == 'Block':
    # NOTE(review): the keyword here is `num_experts`; confirm it matches the
    # installed tl.BlockSparseFF signature.
    return [
        tl.LayerNorm(),
        tl.BlockSparseFF(d_ff, num_experts=ff_sparsity, mode=mode),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
    ]

  dense_ff = ChunkedFeedForward(d_model, d_ff, dropout, ff_activation,
                                ff_dropout, ff_chunk_size, mode)
  return [dense_ff]