def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         observations_dtype,
                         n_actions,
                         bottom_layers_fn=(),
                         two_towers=True):
  """A policy and value net function."""

  # Layers.
  # Now, with the current logits, one head computes action probabilities and
  # the other computes the value function.
  # NOTE: We use LogSoftmax instead of Softmax for numerical stability.
  if two_towers:
    layers = [
        tl.Dup(),
        tl.Parallel(
            [bottom_layers_fn(), tl.Dense(n_actions), tl.LogSoftmax()],
            [bottom_layers_fn(), tl.Dense(1)],
        )
    ]
  else:
    layers = [
        bottom_layers_fn(),
        tl.Dup(),
        tl.Parallel(
            [tl.Dense(n_actions), tl.LogSoftmax()],
            [tl.Dense(1)],
        )
    ]
  net = tl.Model(layers)
  params, state = net.initialize(batch_observations_shape, observations_dtype,
                                 rng_key)
  return params, state, net
def ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode):
  """Residual feed-forward layer with normalization at start."""
  return layers.Residual(
      layers.LayerNorm(),
      layers.Dense(feedforward_depth),
      layers.Relu(),
      layers.Dropout(rate=dropout, mode=mode),
      layers.Dense(feature_depth),
      layers.Dropout(rate=dropout, mode=mode)
  )
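# The following is an illustrative NumPy sketch (not the library's code) of
# what the residual feed-forward block above computes, assuming the standard
# residual-connection semantics y = x + f(LayerNorm(x)). Dropout and the
# learned LayerNorm scale/bias are omitted, and all parameter names are made
# up for the example.
import numpy as np


def residual_feed_forward_reference(x, w1, b1, w2, b2):
  """Reference computation: x + Dense(Relu(Dense(LayerNorm(x))))."""
  # Layer normalization over the feature axis.
  mean = x.mean(axis=-1, keepdims=True)
  var = x.var(axis=-1, keepdims=True)
  h = (x - mean) / np.sqrt(var + 1e-6)
  # Dense -> Relu -> Dense.
  h = np.maximum(h @ w1 + b1, 0.0)
  h = h @ w2 + b2
  # Residual connection: add the block's input back to its output.
  return x + h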
def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         num_actions,
                         bottom_layers_fn=None,
                         two_towers=True):
  """A policy and value net function."""

  # Layers.
  # Now, with the current logits, one head computes action probabilities and
  # the other computes the value function.
  # NOTE: We use LogSoftmax instead of Softmax for numerical stability.
  net = None
  if not two_towers:
    tower = [] if bottom_layers_fn is None else bottom_layers_fn()
    tower.extend([
        layers.Branch(
            layers.Serial(layers.Dense(num_actions), layers.LogSoftmax()),
            layers.Dense(1))
    ])
    net = layers.Serial(*tower)
  else:
    tower1 = [] if bottom_layers_fn is None else bottom_layers_fn()
    tower2 = [] if bottom_layers_fn is None else bottom_layers_fn()

    tower1.extend([layers.Dense(num_actions), layers.LogSoftmax()])
    tower2.extend([layers.Dense(1)])

    net = layers.Branch(
        layers.Serial(*tower1),
        layers.Serial(*tower2),
    )
  assert net
  return net.initialize(batch_observations_shape, rng_key), net
def policy_and_value_net(n_actions, bottom_layers_fn, two_towers):
  """A policy and value net function."""

  # Layers.
  # Now, with the current logits, one head computes action probabilities and
  # the other computes the value function.
  # NOTE: We use LogSoftmax instead of Softmax for numerical stability.
  if two_towers:
    layers = [
        tl.Dup(),
        tl.Parallel(
            [bottom_layers_fn(), tl.Dense(n_actions), tl.LogSoftmax()],
            [bottom_layers_fn(), tl.Dense(1)],
        )
    ]
  else:
    layers = [
        bottom_layers_fn(),
        tl.Dup(),
        tl.Parallel(
            [tl.Dense(n_actions), tl.LogSoftmax()],
            [tl.Dense(1)],
        )
    ]
  return tl.Model(layers)
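# Illustrative NumPy sketch (not the library's code) of the data flow in the
# shared-tower ("not two_towers") variant above, and of why LogSoftmax is
# preferred: the max is subtracted before exponentiation so the normalizer
# never overflows. `bottom`, `policy_head` and `value_head` are made-up
# callables standing in for the corresponding layer stacks.
import numpy as np


def policy_and_value_reference(observations, bottom, policy_head, value_head):
  hidden = bottom(observations)          # shared bottom layers
  logits = policy_head(hidden)           # shape (..., n_actions)
  # Numerically stable log-softmax.
  m = np.max(logits, axis=-1, keepdims=True)
  log_probs = logits - m - np.log(
      np.sum(np.exp(logits - m), axis=-1, keepdims=True))
  value = value_head(hidden)             # shape (..., 1)
  return log_probs, value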
def MultiHeadedAttentionPosition(positions, d_feature, n_heads=8, dropout=0.0,
                                 mode='train'):
  """Transformer-style multi-headed attention."""
  return tl.Serial(
      tl.Dup(),
      tl.Dup(),
      tl.Parallel(
          ApplyAndQueryPositions(
              tl.Dense(d_feature),
              pos=[SumLearnedPick(positions) for _ in range(n_heads)]),
          PreservePosition(tl.Dense(d_feature)),
          PreservePosition(tl.Dense(d_feature)),
      ),
      tl.Parallel(
          CopyHeadsPos(h=n_heads),
          MixHeadsPos(h=n_heads),
          MixHeadsPos(h=n_heads),
      ),
      tl.PureMultiHeadedAttention(
          d_feature=d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Parallel([], tl.Drop()),  # Drop the mask.
      CombineHeadsPos(h=n_heads),
      PreservePosition(tl.Dense(d_feature)),
  )
def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         n_actions,
                         bottom_layers_fn=(),
                         two_towers=True):
  """A policy and value net function."""

  # Layers.
  # Now, with the current logits, one head computes action probabilities and
  # the other computes the value function.
  # NOTE: We use LogSoftmax instead of Softmax for numerical stability.
  if two_towers:
    net = tl.Branch(
        [bottom_layers_fn(), tl.Dense(n_actions), tl.LogSoftmax()],
        [bottom_layers_fn(), tl.Dense(1)])
  else:
    net = tl.Serial(
        bottom_layers_fn(),
        tl.Branch(
            [tl.Dense(n_actions), tl.LogSoftmax()],
            [tl.Dense(1)]))
  return net.initialize(batch_observations_shape, rng_key), net
def DecoderBlock(d_feature, d_feedforward, n_heads, n_attention_chunks,
                 attention_loop_stride, dropout, mode):
  """Reversible transformer decoder layer.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_loop_stride: int: number of query elements to compute attention
      for in parallel. Set to 0 to disable memory-efficient attention.
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  pre_attention = [
      Chunk(sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
      tl.LayerNorm(),
      tl.Dup(), tl.Dup(),
      tl.Parallel(
          [tl.Dense(d_feature), SplitHeads(n_heads=n_heads)],  # pylint: disable=no-value-for-parameter
          [tl.Dense(d_feature), SplitHeads(n_heads=n_heads)],  # pylint: disable=no-value-for-parameter
          [tl.Dense(d_feature), SplitHeads(n_heads=n_heads)],  # pylint: disable=no-value-for-parameter
      ),
  ]

  # TODO(kitaev): add dropout
  if attention_loop_stride < 1:
    # Use the standard implementation if no loop_stride is provided.
    attention = DotProductAttention(dropout=None, mode=mode)
  else:
    attention = MemoryEfficientDotProductAttention(
        loop_stride=attention_loop_stride, dropout=None, mode=mode)

  # ReversibleAttentionHalfResidual requires that post_attention be linear in
  # its input (so the backward pass can be computed without knowing the input)
  post_attention = [
      JoinHeads(),  # pylint: disable=no-value-for-parameter
      tl.Dense(d_feature),
      Unchunk(sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
  ]

  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      ReversibleAttentionHalfResidual(pre_attention, attention, post_attention),
      ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      ReversibleSwap(),
  ]
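# Toy NumPy sketch of the reversible-residual idea that the
# ReversibleHalfResidual / ReversibleSwap layers above rely on. This uses the
# standard RevNet-style coupling and is not the library's implementation:
# the point is that earlier activations can be recomputed exactly from later
# ones, so they need not be stored for backprop.
import numpy as np


def rev_block_forward(x1, x2, f, g):
  y1 = x1 + f(x2)
  y2 = x2 + g(y1)
  return y1, y2


def rev_block_inverse(y1, y2, f, g):
  # Invert the forward pass using only the outputs.
  x2 = y2 - g(y1)
  x1 = y1 - f(x2)
  return x1, x2


f = lambda v: np.tanh(v)
g = lambda v: np.maximum(v, 0.0)
x1, x2 = np.ones((3, 8)), 0.5 * np.ones((3, 8))
y1, y2 = rev_block_forward(x1, x2, f, g)
r1, r2 = rev_block_inverse(y1, y2, f, g)
assert np.allclose(r1, x1) and np.allclose(r2, x2)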
def policy_and_value_net(n_actions, n_controls, vocab_size, bottom_layers_fn,
                         two_towers):
  """A policy and value net function."""

  # Layers.
  # Now, with the current logits, one head computes action probabilities and
  # the other computes the value function.
  # NOTE: We use LogSoftmax instead of Softmax for numerical stability.

  @tl.layer()
  def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
    """Splits logits for actions in different controls and flattens controls."""
    return np.reshape(x, (x.shape[0], -1, n_actions))

  if vocab_size is None:
    # In continuous policies every element of the output sequence corresponds
    # to an observation.
    n_preds_per_input = n_controls
    kwargs = {}
  else:
    # In discrete policies every element of the output sequence corresponds to
    # a symbol in the discrete representation, and each control takes 1 symbol.
    n_preds_per_input = 1
    kwargs = {"vocab_size": vocab_size}

  if two_towers:
    layers = [
        tl.Dup(),
        tl.Parallel(
            [bottom_layers_fn(**kwargs),
             tl.Dense(n_preds_per_input * n_actions),
             FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
             tl.LogSoftmax()],
            [bottom_layers_fn(**kwargs),
             tl.Dense(n_preds_per_input),
             tl.Flatten()],
        )
    ]
  else:
    layers = [
        bottom_layers_fn(**kwargs),
        tl.Dup(),
        tl.Parallel(
            [tl.Dense(n_preds_per_input * n_actions),
             FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
             tl.LogSoftmax()],
            [tl.Dense(n_preds_per_input), tl.Flatten()],
        )
    ]
  return tl.Model(layers)
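# Shape check (illustrative, with made-up sizes) for the reshape performed by
# FlattenControlsIntoTime above: logits of shape
# (batch, time, n_controls * n_actions) become
# (batch, time * n_controls, n_actions), so each control gets its own
# position on the time axis.
import numpy as np

batch, time, n_controls, n_actions = 2, 5, 3, 4
x = np.zeros((batch, time, n_controls * n_actions), dtype=np.float32)
flat = np.reshape(x, (x.shape[0], -1, n_actions))
assert flat.shape == (batch, time * n_controls, n_actions)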
def common_layers():
  """Optional observation preprocessing followed by a two-layer tanh MLP."""
  cur_layers = []
  if FLAGS.flatten_non_batch_time_dims:
    cur_layers = [
        layers.Div(divisor=255.0),
        layers.Flatten(num_axis_to_keep=2)
    ]
  body = [layers.Dense(64), layers.Tanh(), layers.Dense(64), layers.Tanh()]
  return cur_layers + body
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  pre_attention = [
      Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
      tl.LayerNorm(),
      tl.Dup(), tl.Dup(),
      tl.Parallel(
          [tl.Dense(d_attention_key * n_heads),
           SplitHeads(n_heads=n_heads)],  # pylint: disable=no-value-for-parameter
          [tl.Dense(d_attention_key * n_heads),
           SplitHeads(n_heads=n_heads)],  # pylint: disable=no-value-for-parameter
          [tl.Dense(d_attention_value * n_heads),
           SplitHeads(n_heads=n_heads)],  # pylint: disable=no-value-for-parameter
      ),
  ]

  attention = attention_type(mode=mode)

  # ReversibleAttentionHalfResidual requires that post_attention be linear in
  # its input (so the backward pass can be computed without knowing the input)
  post_attention = [
      JoinHeads(),  # pylint: disable=no-value-for-parameter
      tl.Dense(d_model),
      Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
  ]

  feed_forward = [
      FeedForward(d_model, d_ff, dropout, mode=mode),
  ]
  return [
      ReversibleAttentionHalfResidual(pre_attention, attention, post_attention),
      tl.ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
def FeedForward(d_feature, d_feedforward, dropout, mode):
  """Feed-forward block with layer normalization at start."""
  return [
      tl.LayerNorm(),
      tl.Dense(d_feedforward),
      tl.Relu(),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Dense(d_feature),
      tl.Dropout(rate=dropout, mode=mode),
  ]
def FeedForward(d_model, d_ff, dropout, mode):
  """Feed-forward block with layer normalization at start."""
  return [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Relu(),
      tl.Dense(d_model),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
  ]
def FeedForward(d_model, d_ff, dropout, layer_idx, mode):
  """Feed-forward block with layer normalization at start."""
  return [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      tl.Relu(),
      tl.Dropout(rate=dropout, name='ff_middle_%d' % layer_idx, mode=mode),
      tl.Dense(d_model),
      tl.Dropout(rate=dropout, name='ff_final_%d' % layer_idx, mode=mode),
  ]
def FrameStackMLP(n_frames=4, hidden_sizes=(64,), output_size=64,
                  mode='train'):
  """MLP operating on a fixed number of last frames."""
  del mode
  return tl.Model(
      FrameStack(n_frames=n_frames),
      [[tl.Dense(d_hidden), tl.Relu()] for d_hidden in hidden_sizes],
      tl.Dense(output_size),
  )
def ResidualFeedForward(d_feature, d_feedforward, dropout, mode):
  """Residual feed-forward layer with normalization at start."""
  return Residual(
      tl.LayerNorm(),
      tl.Dense(d_feedforward),
      tl.Relu(),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Dense(d_feature),
      tl.Dropout(rate=dropout, mode=mode)
  )
def FeedForward(d_feature, d_feedforward, dropout, mode):
  """Feed-forward block with layer normalization at start."""
  # TODO(kitaev): dropout is disabled to save memory
  del dropout, mode
  return [
      tl.LayerNorm(),
      tl.Dense(d_feedforward),
      tl.Relu(),
      # tl.Dropout(rate=dropout, mode=mode),
      tl.Dense(d_feature),
      # tl.Dropout(rate=dropout, mode=mode),
  ]
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=tl.Relu,
        num_output_classes=10,
        mode="train"):
  """Multi-layer feed-forward neural network with non-linear activations."""
  del mode
  cur_layers = [tl.Flatten()]
  for _ in range(num_hidden_layers):
    cur_layers += [tl.Dense(hidden_size), activation_fn()]
  cur_layers += [tl.Dense(num_output_classes), tl.LogSoftmax()]
  return tl.Serial(*cur_layers)
def FeedForward(d_model, d_ff, dropout, mode):
  """Feed-forward block with layer normalization at start."""
  # TODO(kitaev): add dropout. Dropout is typically performed by adding noise
  # to the activations, but when the size of the activations is very large it
  # is more efficient to add noise to the *parameters* instead.
  del dropout, mode
  return [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      tl.Relu(),
      tl.Dense(d_model),
  ]
def common_layers():
  """Observation preprocessing (for Pong) followed by a small ReLU MLP."""
  cur_layers = []
  if FLAGS.env_name == "Pong-v0":
    cur_layers = [
        layers.Div(divisor=255.0),
        layers.Flatten(num_axis_to_keep=2)
    ]
  return cur_layers + [
      layers.Dense(16), layers.Relu(),
      layers.Dense(4), layers.Relu()
  ]
def common_layers():
  """Atari layers for NoFrameskip envs, otherwise an MLP over observations."""
  # TODO(afrozm): Refactor.
  if "NoFrameskip" in FLAGS.env_problem_name:
    return atari_layers()

  cur_layers = []
  if FLAGS.flatten_dims:
    cur_layers = [
        layers.Div(divisor=255.0),
        layers.Flatten(num_axis_to_keep=2)
    ]
  body = [layers.Dense(64), layers.Tanh(), layers.Dense(64), layers.Tanh()]
  return cur_layers + body
def test_dense_param_sharing(self):
  model1 = layers.Serial(layers.Dense(32), layers.Dense(32))
  layer = layers.Dense(32)
  model2 = layers.Serial(layer, layer)
  rng = random.get_prng(0)
  params1 = model1.initialize((-1, 32), rng)
  params2 = model2.initialize((-1, 32), rng)
  # The first model's parameters have 2 kernels of size (32, 32).
  self.assertEqual((32, 32), params1[0][0].shape)
  self.assertEqual((32, 32), params1[1][0].shape)
  # The second model's parameters have 1 kernel of size (32, 32); the reused
  # layer contributes only an empty tuple, since its parameters are shared.
  self.assertEqual((32, 32), params2[0][0].shape)
  self.assertEqual((), params2[1])
def ResidualFeedForward(d_feature, d_feedforward, dropout, mode):
  """Residual feed-forward layer with normalization at start."""
  stack = tl.Serial(
      tl.LayerNorm(),
      tl.Dense(d_feedforward),
      tl.Relu(),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Dense(d_feature),
      tl.Dropout(rate=dropout, mode=mode)
  )
  return tl.Residual(PreservePosition(stack))
def MLP(n_hidden_layers=2,
        d_hidden=512,
        activation_fn=tl.Relu,
        n_output_classes=10,
        mode="train"):
  """Multi-layer feed-forward neural network with non-linear activations."""
  del mode
  return [
      tl.Flatten(),
      [[tl.Dense(d_hidden), activation_fn()] for _ in range(n_hidden_layers)],
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  ]
def MLP(n_hidden_layers=2,
        d_hidden=512,
        activation_fn=tl.Relu,
        n_output_classes=10,
        mode="train"):
  """A multi-layer feedforward (perceptron) network."""
  del mode
  return tl.Model(
      tl.Flatten(),
      [[tl.Dense(d_hidden), activation_fn()] for _ in range(n_hidden_layers)],
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def model(mode):
  del mode
  return layers.Serial(
      layers.Parallel(
          layers.Flatten(),  # Observation stack.
          layers.Embedding(d_feature=1, vocab_size=n_actions),  # Action.
      ),
      layers.Concatenate(),
      layers.Dense(n_units=1),
      layers.Dup(),
      layers.Parallel(
          layers.Dense(n_units=obs_shape[1]),  # New observation.
          None,  # Reward.
      )
  )
def ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode):
  """Residual feed-forward layer with normalization at start."""
  return layers.Residual(
      layers.LayerNorm(),
      layers.Dense(feedforward_depth,
                   kernel_initializer=layers.XavierUniformInitializer()),
      layers.Relu(),
      layers.Dropout(rate=dropout, mode=mode),
      layers.Dense(feature_depth,
                   kernel_initializer=layers.XavierUniformInitializer()),
      layers.Dropout(rate=dropout, mode=mode)
  )
def WideResnet(n_blocks=3, d_hidden=64, n_output_classes=10, mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group.
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: int, number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
  del mode
  return tl.Model(
      tl.ToFloat(),
      tl.Conv(d_hidden, (3, 3), padding='SAME'),
      WideResnetGroup(n_blocks, d_hidden),
      WideResnetGroup(n_blocks, d_hidden * 2, (2, 2)),
      WideResnetGroup(n_blocks, d_hidden * 4, (2, 2)),
      tl.BatchNorm(),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def PositionLookupTransformerLM(vocab_size=128,
                                d_feature=256,
                                d_feedforward=512,
                                n_layers=3,
                                n_heads=4,
                                dropout=0.1,
                                max_len=100,
                                mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: maximal length
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  positions = _POSITIONS[:max_len, :]
  return tl.Serial([
      tl.ShiftRight(),
      tl.Embedding(d_feature, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      NewPositionalEncoding(positions=positions),
      [DecoderLayer(positions, d_feature, d_feedforward, n_heads, dropout,
                    mode)
       for _ in range(n_layers)],
      PreservePosition(tl.LayerNorm()),
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  ])
def TransformerDecoder(d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
  """Returns a Transformer decoder model.

  The input to the model is a continuous tensor. Does not shift the input to
  the right, i.e. the output for timestep t is based on inputs up to timestep
  t inclusively.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer decoder as a layer that maps from a continuous tensor to
    a continuous tensor.
  """
  return tl.Model(                    # vecs
      tl.PositionalEncoding(max_len=max_len),
      tl.Dense(d_model),              # vecs
      [DecoderBlock(d_model, d_ff, n_heads, dropout, mode)
       for _ in range(n_layers)],     # vecs
      tl.LayerNorm(),                 # vecs
  )
def WideResnet(n_blocks=3, widen_factor=1, n_output_classes=10, mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group. Total layers = 6n + 4.
    widen_factor: int, widening factor of each group. k=1 is vanilla resnet.
    n_output_classes: int, number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
  return tl.Model(
      tl.ToFloat(),
      tl.Conv(16, (3, 3), padding='SAME'),
      WideResnetGroup(n_blocks, 16 * widen_factor, mode=mode),
      WideResnetGroup(n_blocks, 32 * widen_factor, (2, 2), mode=mode),
      WideResnetGroup(n_blocks, 64 * widen_factor, (2, 2), mode=mode),
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )