def test_train_eval_predict_sm3(self):
  with self.tmp_dir() as output_dir:
    # Prepare model and inputs
    n_classes = 4
    train_steps = 2
    eval_steps = 2
    model_fn = functools.partial(models.MLP,
                                 d_hidden=16,
                                 n_output_classes=n_classes)
    inputs = lambda _: test_inputs(n_classes)

    # Train and evaluate
    state = trax.train(output_dir,
                       model=model_fn,
                       inputs=inputs,
                       train_steps=train_steps,
                       eval_steps=eval_steps,
                       optimizer=trax_opt.SM3)

    # Assert total train steps
    self.assertEqual(train_steps, state.step)

    # Assert 2 evaluations ran
    train_acc = state.history.get("train", "metrics/accuracy")
    eval_acc = state.history.get("eval", "metrics/accuracy")
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertEqual(2, len(eval_acc))

    # Predict with final params
    inputs = inputs(1).train_stream()
    model = layers.Serial(model_fn())
    model(next(inputs)[0], state.params[0])
def Generator(encoded_target):
  return layers.Serial(
      encoded_target,
      layers.Dense(target_vocab_size,
                   kernel_initializer=layers.XavierUniformInitializer()),
      layers.LogSoftmax
  )
def SumLearnedPick(positions):
  """Get a pair (vec, pos) and pick new pos."""
  succ_keys = positions[:-1, :]
  succ_values = positions[1:, :]
  subtract_1_keys = positions[1:, :]
  subtract_1_values = positions[:-1, :]
  l = int(positions.shape[0]) // 2
  add_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                       for i in range(l) for j in range(l)])
  add_values = np.array([positions[i + j, :]
                         for i in range(l) for j in range(l)])
  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  sub_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                       for j in range(l) for i in range(l)])
  sub_values = np.array([positions[max(i - j, 0), :]
                         for j in range(l) for i in range(l)])
  return tl.Serial(
      tl.Branch(
          LearnedQP(),
          LearnedQP(keys=succ_keys, values=succ_values),
          LearnedQP(keys=subtract_1_keys, values=subtract_1_values),
          LearnedQP(keys=add_keys, values=add_values, binary=True),
          LearnedQP(keys=sub_keys, values=sub_values, binary=True),
      ),
      Unnest(),
      SoftmaxBranches(n_branches=5)
  )
def PositionLookupTransformerLM(vocab_size=128,
                                d_feature=256,
                                d_feedforward=512,
                                n_layers=3,
                                n_heads=4,
                                dropout=0.1,
                                max_len=100,
                                mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: maximal length
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  positions = _POSITIONS[:max_len, :]
  return tl.Serial([
      tl.ShiftRight(),
      tl.Embedding(d_feature, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      NewPositionalEncoding(positions=positions),
      [DecoderLayer(positions, d_feature, d_feedforward, n_heads, dropout, mode)
       for _ in range(n_layers)],
      PreservePosition(tl.LayerNorm()),
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  ])
def test_train_eval_predict_sm3(self, backend_name):
  if xla_bridge.device_count() > 1 and backend_name == "tf":
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Prepare model and inputs
    n_classes = 4
    train_steps = 2
    eval_steps = 2
    model_fn = functools.partial(models.MLP,
                                 d_hidden=16,
                                 n_output_classes=n_classes)
    inputs = lambda _: test_inputs(n_classes)

    # Train and evaluate
    state = trax.train(output_dir,
                       model=model_fn,
                       inputs=inputs,
                       train_steps=train_steps,
                       eval_steps=eval_steps,
                       optimizer=trax_opt.SM3)

    # Assert total train steps
    self.assertEqual(train_steps, state.step)

    # Assert 2 evaluations ran
    train_acc = state.history.get("train", "metrics/accuracy")
    eval_acc = state.history.get("eval", "metrics/accuracy")
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertLen(eval_acc, 2)

    # Predict with final params
    inputs = inputs(1).train_stream()
    model = layers.Serial(model_fn())
    model(next(inputs)[0], params=state.opt_state.params)
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer decoder layer.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return layers.Serial(
      layers.Residual(  # Self-attention block.
          layers.LayerNorm(),
          layers.Branch(),
          layers.Parallel(layers.Identity(),  # activation for (q, k, v)
                          layers.CausalMask(axis=-2)),  # attention mask
          layers.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                      dropout=dropout, mode=mode),
          layers.Dropout(rate=dropout, mode=mode)),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)
  )
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer decoder layer.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return tl.Serial(
      tl.Residual(  # Self-attention block.
          tl.LayerNorm(),
          tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # Create mask.
          tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                  dropout=dropout, mode=mode),
          tl.Select(0),  # Drop the mask.
          tl.Dropout(rate=dropout, mode=mode)),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)
  )
def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer encoder layer.

  The input to the encoder is a pair (embedded source, mask) where
  the mask is created from the original source to prevent attending
  to the padding part of the input.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (activations, mask).
  """
  return tl.Serial(
      tl.Residual(  # Attention block here.
          tl.Parallel(tl.LayerNorm(), tl.Copy()),
          tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                  dropout=dropout, mode=mode),
          tl.Parallel(tl.Dropout(rate=dropout, mode=mode), tl.Copy())),
      tl.Parallel(
          ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                              mode=mode),
          tl.Div(divisor=2.0)  # Mask added to itself in the residual, divide.
      )
  )
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10,
               mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.
    mode: is it training or eval.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
  del mode
  return tl.Serial(
      tl.Conv(hidden_size, (3, 3), padding='SAME'),
      WideResnetGroup(num_blocks, hidden_size),
      WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
      WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
      tl.BatchNorm(),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(num_output_classes),
      tl.LogSoftmax()
  )
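# A minimal shape-check sketch for the WideResnet above, following the
# `check_shape_agreement` pattern used in the transformer test later in this
# file. The CIFAR-like (1, 32, 32, 3) input shape and the expected (1, 10)
# output are illustrative assumptions, not part of the original code.
def _check_wide_resnet_shape():
  input_shape = (1, 32, 32, 3)  # hypothetical NHWC image batch
  final_shape = tl.check_shape_agreement(WideResnet(), input_shape)
  assert final_shape == (1, 10)  # num_output_classes defaults to 10
  return final_shape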
def ChunkedDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        chunk_selector, mode):
  """Transformer decoder layer operating on chunks.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    chunk_selector: a function from chunk number to list of chunks to attend.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return layers.Serial(
      layers.Residual(  # Self-attention block.
          layers.Map(layers.LayerNorm()),
          layers.ChunkedCausalMultiHeadedAttention(
              feature_depth, num_heads=num_heads, dropout=dropout,
              chunk_selector=chunk_selector, mode=mode),
          layers.Map(layers.Dropout(rate=dropout, mode=mode)),
      ),
      layers.Map(ResidualFeedForward(
          feature_depth, feedforward_depth, dropout, mode=mode))
  )
def _jit_predict_fn(model_predict, metric_fn, n_devices, jit=True):
  """Returns a JIT-compiled predict function (unless jit=False)."""
  model_predict = layers.Serial([model_predict, metric_fn])
  if n_devices == 1:
    return backend.jit(model_predict) if jit else model_predict

  # Multi-devices, pmap and run.
  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_predict(x, params, state, rng):
    return model_predict(x, params=params, state=state, rng=rng)

  def predict(x, params=(), state=(), rng=None):
    """Predict function jited and parallelized as requested."""
    pred = mapped_predict(
        reshape_by_device(x, n_devices),
        params,
        state,
        jax_random.split(rng, n_devices))
    # Need to reduce the [device, per-device-batch, ...] tensors back to
    # a [batch, ...] tensor. The tensors may be nested.
    def combine(x):
      if len(x.shape) > 1:
        batch_size = x.shape[0] * x.shape[1]
        return np.reshape(x, [batch_size] + list(x.shape[2:]))
      # TODO(lukaszkaiser): is returning averages for scalars the right choice?
      # If it is only a scalar, return the average.
      return np.mean(x, axis=0)
    return layers.nested_map(pred, combine)

  return predict
def policy_and_value_net(rng_key, batch_observations_shape, n_actions,
                         bottom_layers_fn=(), two_towers=True):
  """A policy and value net function."""
  # Layers.
  # With the current logits, one head computes action probabilities and the
  # other computes the value function.
  # NOTE: We use LogSoftmax instead of Softmax for numerical stability.
  if two_towers:
    net = tl.Branch(
        [bottom_layers_fn(), tl.Dense(n_actions), tl.LogSoftmax()],
        [bottom_layers_fn(), tl.Dense(1)])
  else:
    net = tl.Serial(
        bottom_layers_fn(),
        tl.Branch(
            [tl.Dense(n_actions), tl.LogSoftmax()],
            [tl.Dense(1)]))
  return net.initialize(batch_observations_shape, rng_key), net
def __init__(self, pre_attention, attention, post_attention):
  self.pre_attention = tl.Serial([
      # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
      tl.Parallel([], tl.Dup()),
      tl.Swap(),
      tl.Parallel(pre_attention, [], []),
  ])
  assert hasattr(attention, 'forward_and_vjp')
  self.attention = ApplyAttentionWrapper(attention)
  self.post_attention = tl.Parallel(post_attention, [], [])

  layers = [
      self.pre_attention,
      self.attention,
      self.post_attention,
      tl.Parallel(tl.Add(), []),
  ]
  super(ReversibleAttentionHalfResidual, self).__init__(layers)

  self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
  self.reverse_layers = [
      self.pre_attention,
      self.attention,
      self.post_attention,
      self.subtract_top,
  ]
def ChunkedDecoderLayer(d_feature, d_feedforward, n_heads, dropout,
                        chunk_selector, mode):
  """Transformer decoder layer operating on chunks.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    chunk_selector: a function from chunk number to list of chunks to attend.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return tl.Serial(
      Residual(  # Self-attention block.
          tl.Map(tl.LayerNorm()),
          ChunkedCausalMultiHeadedAttention(
              d_feature, n_heads=n_heads, dropout=dropout,
              chunk_selector=chunk_selector, mode=mode),
          tl.Map(tl.Dropout(rate=dropout, mode=mode)),
      ),
      tl.Map(ResidualFeedForward(d_feature, d_feedforward, dropout, mode=mode))
  )
def PreservePosition(layer):
  """Execute layer without position but preserve it in parallel."""
  return tl.Serial(
      CutAtPosition(),
      layer,
      tl.Concatenate(n_items=2)
  )
def PreservePosition(layer):
  """Execute layer without position but preserve it in parallel."""
  return tl.Serial(
      CutPosition(),
      layer,
      ConcatenateN()
  )
def AttentionPosition(positions, d_model, n_heads=8, dropout=0.0,
                      mode='train'):
  """Transformer-style multi-headed attention."""
  return tl.Serial(
      tl.Dup(),
      tl.Dup(),
      tl.Parallel(
          ApplyAndQueryPositions(
              tl.Dense(d_model),
              pos=[SumLearnedPick(positions) for _ in range(n_heads)]),
          PreservePosition(tl.Dense(d_model)),
          PreservePosition(tl.Dense(d_model)),
      ),
      tl.Parallel(
          CopyHeadsPos(h=n_heads),
          MixHeadsPos(h=n_heads),
          MixHeadsPos(h=n_heads),
      ),
      tl.PureAttention(d_model=d_model, n_heads=n_heads, dropout=dropout,
                       mode=mode),
      tl.Parallel([], tl.Drop()),  # Drop the mask.
      CombineHeadsPos(h=n_heads),
      PreservePosition(tl.Dense(d_model)),
  )
def Decoder(memory, target, target_mask, memory_mask):
  """Transformer decoder stack.

  Args:
    memory: layer variable: encoded source sequences
    target: layer variable: raw target sequences
    target_mask: layer variable: self-attention mask
    memory_mask: layer variable: memory attention mask

  Returns:
    Layer variable that outputs encoded source.
  """
  decoder_layer = layers.Serial(
      # target attends to self
      layers.Residual(
          layers.LayerNorm(),
          layers.Branch(size=4),
          layers.Parallel(layers.Identity(),  # query
                          layers.Identity(),  # key
                          layers.Identity(),  # value
                          target_mask),       # attention mask
          multi_attention,
          layers.Dropout(dropout, mode=mode)),
      # target attends to encoded source
      layers.Residual(
          layers.LayerNorm(),
          layers.Branch(size=4),
          layers.Parallel(layers.Identity(),  # query
                          memory,             # key
                          memory,             # value
                          memory_mask),       # attention mask
          multi_attention,
          layers.Dropout(dropout, mode=mode)),
      # feed-forward
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)
  )
  return layers.Serial(
      target,
      target_embedding_layer,
      layers.repeat(decoder_layer, num_layers),
      layers.LayerNorm(),
  )
def test_transformer_lm_forward_shape(self):
  """Run the Transformer LM forward and check output shape."""
  vocab_size = 16
  input_shape = [3, 5]
  model = transformer.TransformerLM(
      vocab_size, d_feature=32, d_feedforward=64, n_layers=2, n_heads=2)
  final_shape = tl.check_shape_agreement(
      tl.Serial(model), tuple(input_shape), integer_inputs=True)
  self.assertEqual(tuple(input_shape + [vocab_size]), final_shape)
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  """WideResnet convolutional block."""
  main = tl.Serial(
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(channels, (3, 3), strides, padding='SAME'),
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(channels, (3, 3), padding='SAME'))
  shortcut = tl.Copy() if not channel_mismatch else tl.Conv(
      channels, (3, 3), strides, padding='SAME')
  return tl.Residual(main, shortcut=shortcut)
def IdentityBlock(kernel_size, filters):
  """ResNet identical size block."""
  ks = kernel_size
  filters1, filters2, filters3 = filters
  main = tl.Serial(
      tl.Conv(filters1, (1, 1)),
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(filters2, (ks, ks), padding='SAME'),
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(filters3, (1, 1)),
      tl.BatchNorm()
  )
  return tl.Serial(
      tl.Residual(main),
      tl.Relu()
  )
def TransformerEncoder(vocab_size,
                       num_classes=10,
                       feature_depth=512,
                       feedforward_depth=2048,
                       num_layers=6,
                       num_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
  """Transformer encoder.

  Args:
    vocab_size: int: vocab size
    num_classes: how many classes on output
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer encoder layer.
  """
  input_embedding = layers.Serial(
      layers.Embedding(feature_depth, vocab_size),
      layers.Dropout(rate=dropout, mode=mode),
      layers.PositionalEncoding(max_len=max_len)
  )
  return layers.Serial(
      layers.Branch(),  # Branch input to create embedding and mask.
      layers.Parallel(input_embedding, layers.PaddingMask()),
      layers.Serial(*[EncoderLayer(feature_depth, feedforward_depth, num_heads,
                                   dropout, mode)
                      for _ in range(num_layers)]),
      layers.FirstBranch(),  # Drop the mask.
      layers.LayerNorm(),
      layers.Mean(axis=1),  # Average on length.
      layers.Dense(num_classes),
      layers.LogSoftmax()
  )
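# A minimal shape-check sketch for the TransformerEncoder above, reusing the
# `check_shape_agreement(..., integer_inputs=True)` pattern from the
# TransformerLM test earlier in this file (assuming the same helper is
# available under this `layers` alias). The small vocab/depth settings and
# the (2, 7) token-id batch shape are illustrative assumptions only.
def _check_transformer_encoder_shape():
  model = TransformerEncoder(vocab_size=16, num_classes=4, feature_depth=32,
                             feedforward_depth=64, num_layers=2, num_heads=2)
  final_shape = layers.check_shape_agreement(
      layers.Serial(model), (2, 7), integer_inputs=True)
  assert final_shape == (2, 4)  # (batch, num_classes) log-probabilities
  return final_shape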
def __init__(self, layer, check_shapes=True):
  super(Map, self).__init__()
  if layer is None or isinstance(layer, (list, tuple)):
    layer = tl.Serial(layer)
  self._layer = layer
  # Generally a Map should be applied to lists where all elements have
  # the same shape -- because self._layer will only be initialized once
  # and it could have different parameters for different shapes. But there
  # are valid cases -- e.g., when self._layer has no parameters -- where we
  # can apply Map to different shapes -- set check_shapes=False in such cases.
  self._check_shapes = check_shapes
def ChunkedCausalMultiHeadedAttention(
    d_feature, n_heads=8, dropout=0.0, chunk_selector=None, mode='train'):
  """Transformer-style causal multi-headed attention operating on chunks.

  Accepts inputs that are a list of chunks and applies causal attention.

  Args:
    d_feature: int: depth of embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    chunk_selector: a function from chunk number to list of chunks to attend.
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  prepare_attention_input = tl.Serial(
      tl.Branch(
          tl.Branch(  # q = k = v = first input
              tl.NoOp(), tl.NoOp(), tl.NoOp()),
          tl.CausalMask(axis=-2),
      ),
      tl.Parallel(
          tl.Parallel(
              tl.Dense(d_feature),
              tl.Dense(d_feature),
              tl.Dense(d_feature),
          ),
          tl.NoOp()
      )
  )
  return tl.Serial(
      tl.Map(prepare_attention_input),
      ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
      tl.Map(tl.PureMultiHeadedAttention(
          d_feature=d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
             check_shapes=False),
      tl.Map(tl.Select(0), check_shapes=False),  # drop masks
      tl.Map(tl.Dense(d_feature))
  )
def policy_and_value_net(rng_key, batch_observations_shape, num_actions,
                         bottom_layers=None):
  """A policy and value net function."""
  # Layers.
  cur_layers = []
  if bottom_layers is not None:
    cur_layers.extend(bottom_layers)
  # Now, with the current logits, one head computes action probabilities and
  # the other computes the value function.
  # NOTE: We use LogSoftmax instead of Softmax for numerical stability.
  cur_layers.extend([
      layers.Branch(
          layers.Serial(layers.Dense(num_actions), layers.LogSoftmax()),
          layers.Dense(1))
  ])
  net = layers.Serial(*cur_layers)
  return net.initialize(batch_observations_shape, rng_key), net
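# A minimal usage sketch for the `policy_and_value_net` above, following the
# initialize-then-use pattern implied by its return value. The
# `jax.random.PRNGKey` seed, the (1, 4) observation shape, and the Dense/Relu
# bottom layers are illustrative assumptions, not part of the original code.
def _example_policy_and_value_net_usage():
  import jax  # assumed available, since the surrounding code runs on JAX
  rng_key = jax.random.PRNGKey(0)
  batch_observations_shape = (1, 4)  # hypothetical batch of flat observations
  params, net = policy_and_value_net(
      rng_key, batch_observations_shape, num_actions=2,
      bottom_layers=[layers.Dense(16), layers.Relu()])
  return params, net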
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=tl.Relu,
        num_output_classes=10,
        mode="train"):
  """Multi-layer feed-forward neural network with non-linear activations."""
  del mode
  cur_layers = [tl.Flatten()]
  for _ in range(num_hidden_layers):
    cur_layers += [tl.Dense(hidden_size), activation_fn()]
  cur_layers += [tl.Dense(num_output_classes), tl.LogSoftmax()]
  return tl.Serial(*cur_layers)
def __init__(self, residual_layers):
  self.compute_residual = tl.Serial([
      # TODO(jonni): Rewrite without using Select.
      tl.Select(inputs=('x1_or_y1', 'x2'), output=('x2', 'x1_or_y1', 'x2')),
      tl.Parallel(residual_layers, [], []),
  ])

  layers = [self.compute_residual, tl.Add()]
  super(ReversibleHalfResidual, self).__init__(layers)

  self.subtract_top = tl.SubtractTop()
  self.reverse_layers = [self.compute_residual, self.subtract_top]
def __init__(self, residual_layers):
  self.compute_residual = tl.Serial([
      # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
      tl.Parallel([], tl.Dup()),
      tl.Swap(),
      tl.Parallel(residual_layers, [], []),
  ])

  layers = [self.compute_residual, tl.Parallel(tl.Add(), [])]
  super(ReversibleHalfResidual, self).__init__(layers)

  self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
  self.reverse_layers = [self.compute_residual, self.subtract_top]
def ConvBlock(kernel_size, filters, strides):
  """ResNet convolutional striding block."""
  ks = kernel_size
  filters1, filters2, filters3 = filters
  main = tl.Serial(
      tl.Conv(filters1, (1, 1), strides),
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(filters2, (ks, ks), padding='SAME'),
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(filters3, (1, 1)),
      tl.BatchNorm()
  )
  shortcut = tl.Serial(
      tl.Conv(filters3, (1, 1), strides),
      tl.BatchNorm()
  )
  return tl.Serial(
      tl.Residual(main, shortcut=shortcut),
      tl.Relu()
  )
def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer encoder layer.

  The input to the encoder is a pair (embedded source, mask) where
  the mask is created from the original source to prevent attending
  to the padding part of the input.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (activations, mask).
  """
  # The encoder block expects (activation, mask) as input and returns
  # the new activations only; we add the mask back to the output next.
  encoder_block = layers.Serial(
      layers.Residual(  # Attention block here.
          layers.Parallel(layers.LayerNorm(), layers.Identity()),
          layers.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                      dropout=dropout, mode=mode),
          layers.Dropout(rate=dropout, mode=mode),
          shortcut=layers.FirstBranch()
      ),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)
  )
  # Now we add the mask back.
  return layers.Serial(
      layers.Reorder(output=((0, 1), 1)),  # (x, mask) --> ((x, mask), mask)
      layers.Parallel(encoder_block, layers.Identity())
  )