Example #1
def AtariCnn(hidden_sizes=(32, 32), output_size=128):
    # Input shape: (B, T, H, W, C)
    return tl.Serial(
        tl.Div(divisor=255.0),
        # Make 4 copies of the input, shifted right by 0, 1, 2, and 3 frames.
        tl.Branch(
            tl.NoOp(), tl.ShiftRight(),
            tl.Serial(
                tl.ShiftRight(),
                tl.ShiftRight(),
            ), tl.Serial(
                tl.ShiftRight(),
                tl.ShiftRight(),
                tl.ShiftRight(),
            )),
        # Concatenated on the last axis.
        tl.Concatenate(axis=-1),  # (B, T, H, W, 4C)
        tl.Rebatch(tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'), 2),
        tl.Relu(),
        tl.Rebatch(tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'), 2),
        tl.Relu(),
        tl.Flatten(num_axis_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
        # Output shape: (B, T, output_size)
    )
Example #2
def AtariCnn(hidden_sizes=(32, 32), output_size=128, mode='train'):
    """An Atari CNN."""
    del mode

    # TODO(jonni): Include link to paper?
    # Input shape: (B, T, H, W, C)
    # Output shape: (B, T, output_size)
    return tl.Model(
        tl.ToFloat(),
        tl.Div(divisor=255.0),

        # Set up 4 successive game frames, concatenated on the last axis.
        tl.Dup(),
        tl.Dup(),
        tl.Dup(),
        tl.Parallel(None, _shift_right(1), _shift_right(2), _shift_right(3)),
        tl.Concatenate(n_items=4, axis=-1),  # (B, T, H, W, 4C)
        tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
    )
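The helper `_shift_right` above is not included in this excerpt. A minimal sketch, assuming it simply stacks n copies of `tl.ShiftRight()` to shift the time axis right by n steps, mirroring the repeated-ShiftRight pattern in Example #1:

def _shift_right(n):  # pylint: disable=invalid-name
    # Hypothetical helper (not from the original source): shift the time
    # axis right by n steps by composing n ShiftRight layers.
    return [tl.ShiftRight() for _ in range(n)]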
Example #3
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutational block."""
    main = tl.Serial(tl.BatchNorm(), tl.Relu(),
                     tl.Conv(channels, (3, 3), strides, padding='SAME'),
                     tl.BatchNorm(), tl.Relu(),
                     tl.Conv(channels, (3, 3), padding='SAME'))
    shortcut = tl.Copy() if not channel_mismatch else tl.Conv(
        channels, (3, 3), strides, padding='SAME')
    return tl.Residual(main, shortcut=shortcut)
Example #4
def IdentityBlock(kernel_size, filters):
    """ResNet identical size block."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = tl.Serial(tl.Conv(filters1, (1, 1)), tl.BatchNorm(), tl.Relu(),
                     tl.Conv(filters2, (ks, ks), padding='SAME'),
                     tl.BatchNorm(), tl.Relu(), tl.Conv(filters3, (1, 1)),
                     tl.BatchNorm())
    return tl.Serial(tl.Residual(main), tl.Relu())
Example #5
def WideResnetBlock(channels, strides=(1, 1), mode='train'):
    """WideResnet convolutional block."""
    return [
        tl.BatchNorm(mode=mode),
        tl.Relu(),
        tl.Conv(channels, (3, 3), strides, padding='SAME'),
        tl.BatchNorm(mode=mode),
        tl.Relu(),
        tl.Conv(channels, (3, 3), padding='SAME'),
    ]
Example #6
def ConvBlock(kernel_size, filters, strides):
    """ResNet convolutional striding block."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = tl.Serial(tl.Conv(filters1, (1, 1), strides), tl.BatchNorm(),
                     tl.Relu(), tl.Conv(filters2, (ks, ks), padding='SAME'),
                     tl.BatchNorm(), tl.Relu(), tl.Conv(filters3, (1, 1)),
                     tl.BatchNorm())
    shortcut = tl.Serial(tl.Conv(filters3, (1, 1), strides), tl.BatchNorm())
    return tl.Serial(tl.Residual(main, shortcut=shortcut), tl.Relu())
Example #7
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutational block."""
    main = layers.Serial(
        layers.BatchNorm(), layers.Relu(),
        layers.Conv(channels, (3, 3), strides, padding='SAME'),
        layers.BatchNorm(), layers.Relu(),
        layers.Conv(channels, (3, 3), padding='SAME'))
    shortcut = layers.Identity() if not channel_mismatch else layers.Conv(
        channels, (3, 3), strides, padding='SAME')
    return layers.Serial(layers.Branch(), layers.Parallel(main, shortcut),
                         layers.SumBranches())
Example #8
def IdentityBlock(kernel_size, filters):
    """ResNet identical size block."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = layers.Serial(layers.Conv(filters1, (1, 1)), layers.BatchNorm(),
                         layers.Relu(),
                         layers.Conv(filters2, (ks, ks), padding='SAME'),
                         layers.BatchNorm(), layers.Relu(),
                         layers.Conv(filters3, (1, 1)), layers.BatchNorm())
    return layers.Serial(layers.Branch(),
                         layers.Parallel(main, layers.Identity()),
                         layers.SumBranches(), layers.Relu())
Example #9
def common_layers():
    cur_layers = []
    if FLAGS.flatten_non_batch_time_dims:
        cur_layers = [
            layers.Div(divisor=255.0),
            layers.Flatten(num_axis_to_keep=2)
        ]
    return cur_layers + [
        layers.Dense(16),
        layers.Relu(),
        layers.Dense(4),
        layers.Relu()
    ]
Example #10
def ConvBlock(kernel_size, filters, strides):
    """ResNet convolutional striding block."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = layers.Serial(layers.Conv(filters1, (1, 1), strides),
                         layers.BatchNorm(), layers.Relu(),
                         layers.Conv(filters2, (ks, ks), padding='SAME'),
                         layers.BatchNorm(), layers.Relu(),
                         layers.Conv(filters3, (1, 1)), layers.BatchNorm())
    shortcut = layers.Serial(layers.Conv(filters3, (1, 1), strides),
                             layers.BatchNorm())
    return layers.Serial(layers.Branch(), layers.Parallel(main, shortcut),
                         layers.SumBranches(), layers.Relu())
Example #11
def common_layers():
    cur_layers = []
    if FLAGS.env_name == "Pong-v0":
        cur_layers = [
            layers.Div(divisor=255.0),
            layers.Flatten(num_axis_to_keep=2)
        ]
    return cur_layers + [
        layers.Dense(16),
        layers.Relu(),
        layers.Dense(4),
        layers.Relu()
    ]
Example #12
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10,
               mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.
    mode: is it training or eval.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
  del mode
  return tl.Serial(
      tl.Conv(hidden_size, (3, 3), padding='SAME'),
      WideResnetGroup(num_blocks, hidden_size),
      WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
      WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
      tl.BatchNorm(),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(num_output_classes),
      tl.LogSoftmax()
  )
Example #13
def WideResnet(n_blocks=3, d_hidden=64, n_output_classes=10, mode='train'):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group.
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: int, number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
    del mode
    return tl.Model(
        tl.ToFloat(),
        tl.Conv(d_hidden, (3, 3), padding='SAME'),
        WideResnetGroup(n_blocks, d_hidden),
        WideResnetGroup(n_blocks, d_hidden * 2, (2, 2)),
        WideResnetGroup(n_blocks, d_hidden * 4, (2, 2)),
        tl.BatchNorm(),
        tl.Relu(),
        tl.AvgPool(pool_size=(8, 8)),
        tl.Flatten(),
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
Example #14
def WideResnet(n_blocks=3, widen_factor=1, n_output_classes=10, mode='train'):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group. total layers = 6n + 4.
    widen_factor: int, widening factor of each group. k=1 is vanilla resnet.
    n_output_classes: int, number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
    return tl.Model(
        tl.ToFloat(),
        tl.Conv(16, (3, 3), padding='SAME'),
        WideResnetGroup(n_blocks, 16 * widen_factor, mode=mode),
        WideResnetGroup(n_blocks, 32 * widen_factor, (2, 2), mode=mode),
        WideResnetGroup(n_blocks, 64 * widen_factor, (2, 2), mode=mode),
        tl.BatchNorm(mode=mode),
        tl.Relu(),
        tl.AvgPool(pool_size=(8, 8)),
        tl.Flatten(),
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
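`WideResnetGroup` is referenced by all three WideResnet variants but not shown in these excerpts. A plausible sketch, assuming the `WideResnetBlock` from Example #5 and that only the first block in a group applies the strides, with a strided-conv residual shortcut to absorb the shape change:

def WideResnetGroup(n, channels, strides=(1, 1), mode='train'):
    # Hypothetical sketch (not from the original source): the first block
    # may change width/stride, so its shortcut is a strided convolution;
    # the remaining n - 1 blocks keep shape and use identity shortcuts.
    shortcut = [tl.Conv(channels, (3, 3), strides, padding='SAME')]
    return [
        tl.Residual(WideResnetBlock(channels, strides, mode=mode),
                    shortcut=shortcut),
    ] + [
        tl.Residual(WideResnetBlock(channels, mode=mode))
        for _ in range(n - 1)
    ]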
Example #15
def ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode):
    """Residual feed-forward layer with normalization at start."""
    return layers.Residual(layers.LayerNorm(), layers.Dense(feedforward_depth),
                           layers.Relu(),
                           layers.Dropout(rate=dropout, mode=mode),
                           layers.Dense(feature_depth),
                           layers.Dropout(rate=dropout, mode=mode))
Example #16
def FrameStackMLP(n_frames=4, hidden_sizes=(64,), output_size=64,
                  mode='train'):
  """MLP operating on a fixed number of last frames."""
  del mode

  return tl.Model(
      FrameStack(n_frames=n_frames),
      [[tl.Dense(d_hidden), tl.Relu()] for d_hidden in hidden_sizes],
      tl.Dense(output_size),
  )
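`FrameStack` is not defined in these excerpts either. A plausible sketch, reusing the hypothetical `_shift_right` helper sketched after Example #2 and mirroring Example #2's Dup/Parallel/Concatenate pattern:

def FrameStack(n_frames=4):
    # Hypothetical sketch (not from the original source): make n_frames
    # copies of the input, shift copy k right by k frames, and concatenate
    # the copies on the last axis.
    return [
        [tl.Dup() for _ in range(n_frames - 1)],
        tl.Parallel(None, *[_shift_right(k) for k in range(1, n_frames)]),
        tl.Concatenate(n_items=n_frames, axis=-1),
    ]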
Example #17
def ResidualFeedForward(d_feature, d_feedforward, dropout, mode):
  """Residual feed-forward layer with normalization at start."""
  return Residual(
      tl.LayerNorm(),
      tl.Dense(d_feedforward),
      tl.Relu(),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Dense(d_feature),
      tl.Dropout(rate=dropout, mode=mode)
  )
Example #18
def FeedForward(d_model, d_ff, dropout, layer_idx, mode):
    """Feed-forward block with layer normalization at start."""
    return [
        tl.LayerNorm(),
        tl.Dense(d_ff),
        tl.Relu(),
        tl.Dropout(rate=dropout, name='ff_middle_%d' % layer_idx, mode=mode),
        tl.Dense(d_model),
        tl.Dropout(rate=dropout, name='ff_final_%d' % layer_idx, mode=mode),
    ]
Example #19
def FeedForward(d_feature, d_feedforward, dropout, mode):
    """Feed-forward block with layer normalization at start."""
    return [
        tl.LayerNorm(),
        tl.Dense(d_feedforward),
        tl.Relu(),
        tl.Dropout(rate=dropout, mode=mode),
        tl.Dense(d_feature),
        tl.Dropout(rate=dropout, mode=mode),
    ]
Example #20
def IdentityBlock(kernel_size, filters, mode='train'):
    """ResNet identical size block."""
    # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant.
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = [
        tl.Conv(filters1, (1, 1)),
        tl.BatchNorm(mode=mode),
        tl.Relu(),
        tl.Conv(filters2, (ks, ks), padding='SAME'),
        tl.BatchNorm(mode=mode),
        tl.Relu(),
        tl.Conv(filters3, (1, 1)),
        tl.BatchNorm(mode=mode),
    ]
    return [
        tl.Residual(main),
        tl.Relu(),
    ]
Example #21
def FeedForward(d_model, d_ff, dropout, mode):
    """Feed-forward block with layer normalization at start."""
    return [
        tl.LayerNorm(),
        tl.Dense(d_ff),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        tl.Relu(),
        tl.Dense(d_model),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
    ]
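`BroadcastedDropout` is not shown here. A functional sketch of the idea in JAX, assuming it behaves like `tl.Dropout` except that a single mask is sampled per example and broadcast along the length axis, which keeps the noise tensor small:

import jax
import jax.numpy as jnp

def broadcasted_dropout(x, rate, rng, broadcast_axis=-2):
    # Hypothetical sketch (not from the original source): the mask has
    # size 1 along broadcast_axis, so one mask is shared across that axis.
    if rate == 0.0:
        return x
    mask_shape = list(x.shape)
    mask_shape[broadcast_axis] = 1
    keep = jax.random.bernoulli(rng, 1.0 - rate, tuple(mask_shape))
    return jnp.where(keep, x / (1.0 - rate), jnp.zeros_like(x))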
Example #22
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
    """An Atari CNN."""
    del mode

    # TODO(jonni): Include link to paper?
    # Input shape: (B, T, H, W, C)
    # Output shape: (B, T, output_size)
    return tl.Model(
        tl.ToFloat(),
        tl.Div(divisor=255.0),

        # Set up n_frames successive game frames, concatenated on the last axis.
        FrameStack(n_frames=n_frames),  # (B, T, H, W, n_frames * C)
        tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
    )
Example #23
def FeedForward(d_model, d_ff, dropout, mode):
    """Feed-forward block with layer normalization at start."""
    # TODO(kitaev): add dropout. Dropout is typically performed by adding noise to
    # the activations, but when the size of the activations is very large it is
    # more efficient to add noise to the *parameters* instead.
    del dropout, mode
    return [
        tl.LayerNorm(),
        tl.Dense(d_ff),
        tl.Relu(),
        tl.Dense(d_model),
    ]
Example #24
def FeedForward(d_feature, d_feedforward, dropout, mode):
  """Feed-forward block with layer normalization at start."""
  # TODO(kitaev): dropout is disabled to save memory
  del dropout, mode
  return [
      tl.LayerNorm(),
      tl.Dense(d_feedforward),
      tl.Relu(),
      # tl.Dropout(rate=dropout, mode=mode),
      tl.Dense(d_feature),
      # tl.Dropout(rate=dropout, mode=mode),
  ]
Example #25
def Resnet50(d_hidden=64, n_output_classes=1001, mode='train'):
    """ResNet.

  Args:
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: Number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a ResNet model with the given parameters.
  """
    return tl.Model(
        tl.ToFloat(),
        tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'),
        tl.BatchNorm(mode=mode),
        tl.Relu(),
        tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
        ConvBlock(3, [d_hidden, d_hidden, 4 * d_hidden], (1, 1), mode=mode),
        IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode),
        IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode),
        ConvBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], (2, 2),
                  mode=mode),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden],
                      mode=mode),
        ConvBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], (2, 2),
                  mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        ConvBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], (2, 2),
                  mode=mode),
        IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden],
                      mode=mode),
        tl.AvgPool(pool_size=(7, 7)),
        tl.Flatten(),
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
Example #26
def ConvBlock(kernel_size, filters, strides, mode='train'):
    """ResNet convolutional striding block."""
    # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant.
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = [
        tl.Conv(filters1, (1, 1), strides),
        tl.BatchNorm(mode=mode),
        tl.Relu(),
        tl.Conv(filters2, (ks, ks), padding='SAME'),
        tl.BatchNorm(mode=mode),
        tl.Relu(),
        tl.Conv(filters3, (1, 1)),
        tl.BatchNorm(mode=mode),
    ]
    shortcut = [
        tl.Conv(filters3, (1, 1), strides),
        tl.BatchNorm(mode=mode),
    ]
    return [
        tl.Residual(main, shortcut=shortcut),
        tl.Relu(),
    ]
Example #27
def ResidualFeedForward(d_feature,
                        d_feedforward,
                        dropout,
                        mode):
  """Residual feed-forward layer with normalization at start."""
  stack = tl.Serial(
      tl.LayerNorm(),
      tl.Dense(d_feedforward),
      tl.Relu(),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Dense(d_feature),
      tl.Dropout(rate=dropout, mode=mode)
  )
  return tl.Residual(PreservePosition(stack))
Example #28
def ResidualFeedForward(feature_depth,
                        feedforward_depth,
                        dropout,
                        mode):
  """Residual feed-forward layer with normalization at start."""
  return layers.Residual(
      layers.LayerNorm(),
      layers.Dense(feedforward_depth,
                   kernel_initializer=layers.XavierUniformInitializer()),
      layers.Relu(),
      layers.Dropout(rate=dropout, mode=mode),
      layers.Dense(feature_depth,
                   kernel_initializer=layers.XavierUniformInitializer()),
      layers.Dropout(rate=dropout, mode=mode)
  )
Example #29
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'):
    """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.
    mode: whether we are training or evaluating or doing inference.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
    del mode
    return tl.Serial(
        tl.Conv(hidden_size, (7, 7), (2, 2),
                'SAME'), tl.BatchNorm(), tl.Relu(),
        tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
        ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
        IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
        ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
        ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
        tl.AvgPool(pool_size=(7, 7)), tl.Flatten(),
        tl.Dense(num_output_classes), tl.LogSoftmax())