def AtariCnn(hidden_sizes=(32, 32), output_size=128):
  """Atari CNN mapping (B, T, H, W, C) frames to (B, T, output_size) features."""
  # Four copies of the input, shifted right by 0, 1, 2, and 3 timesteps.
  frame_stack = tl.Branch(
      tl.NoOp(),
      tl.ShiftRight(),
      tl.Serial(tl.ShiftRight(), tl.ShiftRight()),
      tl.Serial(tl.ShiftRight(), tl.ShiftRight(), tl.ShiftRight()),
  )
  return tl.Serial(
      tl.Div(divisor=255.0),  # Scale raw pixel bytes into [0, 1].
      frame_stack,
      tl.Concatenate(axis=-1),  # (B, T, H, W, 4C)
      tl.Rebatch(tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'), 2),
      tl.Relu(),
      tl.Rebatch(tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'), 2),
      tl.Relu(),
      tl.Flatten(num_axis_to_keep=2),  # Keep B and T; flatten the rest.
      tl.Dense(output_size),
      tl.Relu(),
      # Final shape: (B, T, output_size).
  )
def AtariCnn(hidden_sizes=(32, 32), output_size=128, mode='train'):
  """An Atari CNN."""
  del mode  # Unused; kept for the standard model-constructor signature.

  # TODO(jonni): Include link to paper?
  # Input shape: (B, T, H, W, C); output shape: (B, T, output_size).

  # Four successive game frames, concatenated on the channel axis.
  stack_frames = [
      tl.Dup(),
      tl.Dup(),
      tl.Dup(),
      tl.Parallel(None, _shift_right(1), _shift_right(2), _shift_right(3)),
      tl.Concatenate(n_items=4, axis=-1),  # (B, T, H, W, 4C)
  ]
  return tl.Model(
      tl.ToFloat(),
      tl.Div(divisor=255.0),  # Scale raw pixel bytes into [0, 1].
      *stack_frames,
      tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Flatten(n_axes_to_keep=2),  # Keep B and T; flatten the rest.
      tl.Dense(output_size),
      tl.Relu(),
  )
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  """WideResnet convolutional block."""
  # Two BatchNorm-Relu-Conv stages; only the first applies the strides.
  main = tl.Serial(
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(channels, (3, 3), strides, padding='SAME'),
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(channels, (3, 3), padding='SAME'),
  )
  if channel_mismatch:
    # Project the shortcut so its channels/stride match the main path.
    shortcut = tl.Conv(channels, (3, 3), strides, padding='SAME')
  else:
    shortcut = tl.Copy()
  return tl.Residual(main, shortcut=shortcut)
def IdentityBlock(kernel_size, filters):
  """ResNet identical size block."""
  ks = kernel_size
  f1, f2, f3 = filters
  # Bottleneck: 1x1 reduce, ks x ks conv, 1x1 expand.
  main = tl.Serial(
      tl.Conv(f1, (1, 1)), tl.BatchNorm(), tl.Relu(),
      tl.Conv(f2, (ks, ks), padding='SAME'), tl.BatchNorm(), tl.Relu(),
      tl.Conv(f3, (1, 1)), tl.BatchNorm(),
  )
  return tl.Serial(tl.Residual(main), tl.Relu())
def ConvBlock(kernel_size, filters, strides):
  """ResNet convolutional striding block."""
  ks = kernel_size
  f1, f2, f3 = filters
  # Bottleneck path; the first 1x1 conv carries the striding.
  main = tl.Serial(
      tl.Conv(f1, (1, 1), strides), tl.BatchNorm(), tl.Relu(),
      tl.Conv(f2, (ks, ks), padding='SAME'), tl.BatchNorm(), tl.Relu(),
      tl.Conv(f3, (1, 1)), tl.BatchNorm(),
  )
  # Strided 1x1 projection so the shortcut matches the main path's shape.
  shortcut = tl.Serial(tl.Conv(f3, (1, 1), strides), tl.BatchNorm())
  return tl.Serial(tl.Residual(main, shortcut=shortcut), tl.Relu())
def WideResnetBlock(channels, strides=(1, 1), mode='train'):
  """WideResnet convolutional block: two BatchNorm-Relu-Conv stages."""
  # Only the first conv applies the strides; the second keeps resolution.
  strided_stage = [
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.Conv(channels, (3, 3), strides, padding='SAME'),
  ]
  plain_stage = [
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.Conv(channels, (3, 3), padding='SAME'),
  ]
  return strided_stage + plain_stage
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  """WideResnet convolutional block."""
  # Two BatchNorm-Relu-Conv stages; only the first applies the strides.
  main = layers.Serial(
      layers.BatchNorm(),
      layers.Relu(),
      layers.Conv(channels, (3, 3), strides, padding='SAME'),
      layers.BatchNorm(),
      layers.Relu(),
      layers.Conv(channels, (3, 3), padding='SAME'),
  )
  if channel_mismatch:
    # Project the shortcut so its channels/stride match the main path.
    shortcut = layers.Conv(channels, (3, 3), strides, padding='SAME')
  else:
    shortcut = layers.Identity()
  return layers.Serial(
      layers.Branch(),
      layers.Parallel(main, shortcut),
      layers.SumBranches(),
  )
def IdentityBlock(kernel_size, filters):
  """ResNet identical size block."""
  ks = kernel_size
  f1, f2, f3 = filters
  # Bottleneck: 1x1 reduce, ks x ks conv, 1x1 expand.
  main = layers.Serial(
      layers.Conv(f1, (1, 1)), layers.BatchNorm(), layers.Relu(),
      layers.Conv(f2, (ks, ks), padding='SAME'),
      layers.BatchNorm(), layers.Relu(),
      layers.Conv(f3, (1, 1)), layers.BatchNorm(),
  )
  return layers.Serial(
      layers.Branch(),
      layers.Parallel(main, layers.Identity()),
      layers.SumBranches(),
      layers.Relu(),
  )
def ConvBlock(kernel_size, filters, strides):
  """ResNet convolutional striding block."""
  ks = kernel_size
  f1, f2, f3 = filters
  # Bottleneck path; the first 1x1 conv carries the striding.
  main = layers.Serial(
      layers.Conv(f1, (1, 1), strides), layers.BatchNorm(), layers.Relu(),
      layers.Conv(f2, (ks, ks), padding='SAME'),
      layers.BatchNorm(), layers.Relu(),
      layers.Conv(f3, (1, 1)), layers.BatchNorm(),
  )
  # Strided 1x1 projection so the shortcut matches the main path's shape.
  shortcut = layers.Serial(
      layers.Conv(f3, (1, 1), strides), layers.BatchNorm())
  return layers.Serial(
      layers.Branch(),
      layers.Parallel(main, shortcut),
      layers.SumBranches(),
      layers.Relu(),
  )
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10,
               mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.
    mode: is it training or eval.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
  del mode  # Identical layer stack for train and eval.
  # Three groups with doubling width; only the later two downsample.
  groups = [
      WideResnetGroup(num_blocks, hidden_size),
      WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
      WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
  ]
  return tl.Serial(
      tl.Conv(hidden_size, (3, 3), padding='SAME'),
      *groups,
      tl.BatchNorm(),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(num_output_classes),
      tl.LogSoftmax(),
  )
def WideResnet(n_blocks=3, widen_factor=1, n_output_classes=10, mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group. total layers = 6n + 4.
    widen_factor: int, widening factor of each group. k=1 is vanilla resnet.
    n_output_classes: int, number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
  # Groups at widths 16k/32k/64k; only the later two downsample.
  groups = [
      WideResnetGroup(n_blocks, 16 * widen_factor, mode=mode),
      WideResnetGroup(n_blocks, 32 * widen_factor, (2, 2), mode=mode),
      WideResnetGroup(n_blocks, 64 * widen_factor, (2, 2), mode=mode),
  ]
  return tl.Model(
      tl.ToFloat(),
      tl.Conv(16, (3, 3), padding='SAME'),
      *groups,
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def WideResnet(n_blocks=3, d_hidden=64, n_output_classes=10, mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group.
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: int, number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
  del mode  # Identical layer stack for train and eval.
  # Three groups with doubling width; only the later two downsample.
  groups = [
      WideResnetGroup(n_blocks, d_hidden),
      WideResnetGroup(n_blocks, d_hidden * 2, (2, 2)),
      WideResnetGroup(n_blocks, d_hidden * 4, (2, 2)),
  ]
  return tl.Model(
      tl.ToFloat(),
      tl.Conv(d_hidden, (3, 3), padding='SAME'),
      *groups,
      tl.BatchNorm(),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def WideResnetGroup(n, channels, strides=(1, 1)):
  """Group of n WideResnet blocks: one strided, then n-1 unit-stride."""
  # The first block may change resolution/width, so its shortcut projects.
  shortcut = [
      tl.Conv(channels, (3, 3), strides, padding='SAME'),
  ]
  first = tl.Residual(WideResnetBlock(channels, strides), shortcut=shortcut)
  rest = [WideResnetBlock(channels, (1, 1)) for _ in range(n - 1)]
  return [first, tl.Residual(rest)]
def IdentityBlock(kernel_size, filters, mode='train'):
  """ResNet identical size block."""
  # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant.
  ks = kernel_size
  f1, f2, f3 = filters
  # Bottleneck: 1x1 reduce, ks x ks conv, 1x1 expand (no final Relu in main).
  main = (
      [tl.Conv(f1, (1, 1)), tl.BatchNorm(mode=mode), tl.Relu()] +
      [tl.Conv(f2, (ks, ks), padding='SAME'),
       tl.BatchNorm(mode=mode), tl.Relu()] +
      [tl.Conv(f3, (1, 1)), tl.BatchNorm(mode=mode)]
  )
  return [tl.Residual(main), tl.Relu()]
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
  """An Atari CNN."""
  del mode  # Unused; kept for the standard model-constructor signature.

  # TODO(jonni): Include link to paper?
  # Input shape: (B, T, H, W, C); output shape: (B, T, output_size).
  return tl.Model(
      tl.ToFloat(),
      tl.Div(divisor=255.0),  # Scale raw pixel bytes into [0, 1].
      # n_frames successive game frames, concatenated on the last axis:
      # (B, T, H, W, n_frames * C).
      FrameStack(n_frames=n_frames),
      tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Flatten(n_axes_to_keep=2),  # Keep B and T; flatten the rest.
      tl.Dense(output_size),
      tl.Relu(),
  )
def Resnet50(d_hidden=64, n_output_classes=1001, mode='train'):
  """ResNet.

  Args:
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: Number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a ResNet model with the given parameters.
  """

  def resnet_group(mult, strides, n_identity):
    # One strided ConvBlock followed by n_identity IdentityBlocks, all with
    # bottleneck filters [mult*d, mult*d, 4*mult*d].
    filters = [mult * d_hidden, mult * d_hidden, 4 * mult * d_hidden]
    return ([ConvBlock(3, filters, strides, mode=mode)] +
            [IdentityBlock(3, filters, mode=mode) for _ in range(n_identity)])

  return tl.Model(
      tl.ToFloat(),
      tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'),
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
      *resnet_group(1, (1, 1), 2),   # conv2_x
      *resnet_group(2, (2, 2), 3),   # conv3_x
      *resnet_group(4, (2, 2), 5),   # conv4_x
      *resnet_group(8, (2, 2), 2),   # conv5_x
      tl.AvgPool(pool_size=(7, 7)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def ConvBlock(kernel_size, filters, strides, mode='train'):
  """ResNet convolutional striding block."""
  # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant.
  ks = kernel_size
  f1, f2, f3 = filters
  # Bottleneck path; the first 1x1 conv carries the striding.
  main = [
      tl.Conv(f1, (1, 1), strides), tl.BatchNorm(mode=mode), tl.Relu(),
      tl.Conv(f2, (ks, ks), padding='SAME'),
      tl.BatchNorm(mode=mode), tl.Relu(),
      tl.Conv(f3, (1, 1)), tl.BatchNorm(mode=mode),
  ]
  # Strided 1x1 projection so the shortcut matches the main path's shape.
  shortcut = [
      tl.Conv(f3, (1, 1), strides),
      tl.BatchNorm(mode=mode),
  ]
  return [tl.Residual(main, shortcut=shortcut), tl.Relu()]
def WideResnetGroup(n, channels, strides=(1, 1), bn_momentum=0.9, mode='train'):
  """Group of n WideResnet blocks: one strided, then n-1 unit-stride."""
  # The first block may change resolution/width, so its shortcut projects.
  shortcut = [
      tl.Conv(channels, (3, 3), strides, padding='SAME'),
  ]
  first = WideResnetBlock(channels, strides, bn_momentum=bn_momentum,
                          mode=mode)
  rest = [
      WideResnetBlock(channels, (1, 1), bn_momentum=bn_momentum, mode=mode)
      for _ in range(n - 1)
  ]
  return [
      tl.Residual(first, shortcut=shortcut),
      tl.Residual(rest),
  ]
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'):
  """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.
    mode: whether we are training or evaluating or doing inference.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
  del mode  # Identical layer stack for train and eval.

  def resnet_group(mult, strides, n_identity):
    # One strided ConvBlock followed by n_identity IdentityBlocks, all with
    # bottleneck filters [mult*h, mult*h, 4*mult*h].
    filters = [mult * hidden_size, mult * hidden_size, 4 * mult * hidden_size]
    return ([ConvBlock(3, filters, strides)] +
            [IdentityBlock(3, filters) for _ in range(n_identity)])

  return tl.Serial(
      tl.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
      tl.BatchNorm(),
      tl.Relu(),
      tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
      *resnet_group(1, (1, 1), 2),   # conv2_x
      *resnet_group(2, (2, 2), 3),   # conv3_x
      *resnet_group(4, (2, 2), 5),   # conv4_x
      *resnet_group(8, (2, 2), 2),   # conv5_x
      tl.AvgPool(pool_size=(7, 7)),
      tl.Flatten(),
      tl.Dense(num_output_classes),
      tl.LogSoftmax())
def BuildConv():
  # Builds a same-padded Conv layer. `units` and `kernel_size` are free
  # variables here — presumably closed over from an enclosing scope or
  # injected via configuration (TODO confirm against the caller).
  return tl.Conv(filters=units, kernel_size=kernel_size, padding='SAME')