def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer encoder layer.

  The encoder's input is a pair (embedded source, mask); the mask is built
  from the original source so attention never looks at padding positions.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (activations, mask).
  """
  # Self-attention sub-block, wrapped in a residual connection; the mask
  # passes through untouched on the second wire.
  attention_block = tl.Residual(
      tl.Parallel(tl.LayerNorm(), tl.Copy()),
      tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                              dropout=dropout, mode=mode),
      tl.Parallel(tl.Dropout(rate=dropout, mode=mode), tl.Copy()))
  # Feed-forward on the activations; the mask was added to itself in the
  # residual above, so divide it by 2 to restore its original values.
  feedforward_block = tl.Parallel(
      ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                          mode=mode),
      tl.Div(divisor=2.0))
  return tl.Serial(attention_block, feedforward_block)
def AtariCnn(hidden_sizes=(32, 32), output_size=128, mode='train'):
  """An Atari CNN.

  Input shape: (B, T, H, W, C); output shape: (B, T, output_size).
  """
  del mode  # Unused by this model.

  # TODO(jonni): Include link to paper?
  frame_stacking = [
      # Make 4 copies of the input, shift three of them right by 1..3 time
      # steps, and concatenate on the last axis: (B, T, H, W, 4C).
      tl.Dup(), tl.Dup(), tl.Dup(),
      tl.Parallel(None, _shift_right(1), _shift_right(2), _shift_right(3)),
      tl.Concatenate(n_items=4, axis=-1),
  ]
  feature_extraction = [
      tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'), tl.Relu(),
      tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'), tl.Relu(),
      tl.Flatten(n_axes_to_keep=2),  # Keep B and T; flatten the rest.
      tl.Dense(output_size), tl.Relu(),
  ]
  return tl.Model(tl.ToFloat(), tl.Div(divisor=255.0),
                  *frame_stacking, *feature_extraction)
def AtariCnn(hidden_sizes=(32, 32), output_size=128):
  """An Atari CNN: maps (B, T, H, W, C) inputs to (B, T, output_size)."""
  # Branches holding the input shifted right by 1, 2 and 3 time steps.
  shifted_by_1 = tl.ShiftRight()
  shifted_by_2 = tl.Serial(tl.ShiftRight(), tl.ShiftRight())
  shifted_by_3 = tl.Serial(tl.ShiftRight(), tl.ShiftRight(), tl.ShiftRight())
  return tl.Serial(
      tl.Div(divisor=255.0),
      # Four copies of the input, shifted right by 0..3 steps each.
      tl.Branch(tl.NoOp(), shifted_by_1, shifted_by_2, shifted_by_3),
      tl.Concatenate(axis=-1),  # (B, T, H, W, 4C)
      tl.Rebatch(tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'), 2),
      tl.Relu(),
      tl.Rebatch(tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'), 2),
      tl.Relu(),
      tl.Flatten(num_axis_to_keep=2),  # Keep B and T; flatten the rest.
      tl.Dense(output_size),
      tl.Relu(),  # Final shape: (B, T, output_size).
  )
def common_layers():
  """Policy/value network body; optionally normalizes and flattens inputs."""
  body = [layers.Dense(64), layers.Tanh(), layers.Dense(64), layers.Tanh()]
  if FLAGS.flatten_non_batch_time_dims:
    # Scale pixel values to [0, 1] and flatten all but the batch and
    # time dimensions before the dense body.
    return [layers.Div(divisor=255.0),
            layers.Flatten(num_axis_to_keep=2)] + body
  return body
def test_div(self):
  """Div(divisor=2.0) should halve every element of its input."""
  halve = layers.Div(divisor=2.0)
  x = onp.array([[1, 2, 3], [4, 5, 6]], dtype=onp.float32)
  y = halve(x)
  expected = x / 2.0
  # absltest doesn't have ndarray equalities, so compare via squared error.
  self.assertAlmostEqual(0.0, onp.sum((y - expected)**2), delta=1e-6)
def common_layers():
  """Network body; dispatches to Atari layers for NoFrameskip envs."""
  # TODO(afrozm): Refactor.
  if "NoFrameskip" in FLAGS.env_problem_name:
    return atari_layers()
  body = [layers.Dense(64), layers.Tanh(), layers.Dense(64), layers.Tanh()]
  if not FLAGS.flatten_dims:
    return body
  # Scale pixel values to [0, 1] and flatten all but batch/time dims.
  return [layers.Div(divisor=255.0), layers.Flatten(num_axis_to_keep=2)] + body
def common_layers():
  """Small dense policy body; normalizes pixel inputs for Pong-v0."""
  body = [layers.Dense(16), layers.Relu(), layers.Dense(4), layers.Relu()]
  if FLAGS.env_name == "Pong-v0":
    # Pong observations are pixels: scale to [0, 1] and flatten all
    # dimensions past batch and time before the dense body.
    return [layers.Div(divisor=255.0),
            layers.Flatten(num_axis_to_keep=2)] + body
  return body
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
  """An Atari CNN.

  Args:
    n_frames: int: number of successive game frames to stack.
    hidden_sizes: tuple of ints: output depth of each Conv layer; one
      Conv+Relu pair is built per entry (previously exactly two entries
      were supported; identical behavior for the default 2-tuple).
    output_size: int: depth of the final Dense layer.
    mode: str: 'train' or 'eval'; unused by this model.

  Returns:
    A model mapping (B, T, H, W, C) inputs to (B, T, output_size) outputs.
  """
  del mode  # TODO(jonni): Include link to paper?

  # Generalized from hard-coded hidden_sizes[0] / hidden_sizes[1]: build one
  # Conv+Relu pair per requested hidden size.
  conv_stack = []
  for depth in hidden_sizes:
    conv_stack += [tl.Conv(depth, (5, 5), (2, 2), 'SAME'), tl.Relu()]
  return tl.Model(
      tl.ToFloat(),
      tl.Div(divisor=255.0),
      # Set up n_frames successive game frames, concatenated on the last
      # axis: (B, T, H, W, n_frames * C).
      FrameStack(n_frames=n_frames),
      *conv_stack,
      tl.Flatten(n_axes_to_keep=2),  # Keep B and T; flatten the rest.
      tl.Dense(output_size),
      tl.Relu(),
  )