def WideResnetBlock(channels, strides=(1, 1), bn_momentum=0.9, mode='train'):
  """WideResnet convolutional block: two pre-activation BN-ReLU-Conv stages.

  Args:
    channels: int, number of output channels for both convolutions.
    strides: Strides for the first convolution only.
    bn_momentum: float, momentum used by both BatchNorm layers.
    mode: 'train' / 'eval' / 'predict' mode passed to BatchNorm.

  Returns:
    A list of layers implementing the block.
  """
  def _pre_activation():
    # Pre-activation ordering (norm then non-linearity), per WideResnet.
    return [tl.BatchNorm(momentum=bn_momentum, mode=mode), tl.Relu()]

  strided_stage = _pre_activation() + [
      tl.Conv(channels, (3, 3), strides, padding='SAME')]
  plain_stage = _pre_activation() + [
      tl.Conv(channels, (3, 3), padding='SAME')]
  return strided_stage + plain_stage
def WideResnet(n_blocks=3, widen_factor=1, n_output_classes=10, bn_momentum=0.9,
               mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group. total layers = 6n + 4.
    widen_factor: int, widening factor of each group. k=1 is vanilla resnet.
    n_output_classes: int, number of distinct output classes.
    bn_momentum: float, momentum in BatchNorm.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
  def _group(channels, strides):
    # One group of `n_blocks` residual blocks at the given width/stride.
    return WideResnetGroup(n_blocks, channels, strides,
                           bn_momentum=bn_momentum, mode=mode)

  return tl.Serial(
      tl.ToFloat(),
      tl.Conv(16, (3, 3), padding='SAME'),
      _group(16 * widen_factor, (1, 1)),
      _group(32 * widen_factor, (2, 2)),  # Halve spatial resolution.
      _group(64 * widen_factor, (2, 2)),  # Halve again.
      tl.BatchNorm(momentum=bn_momentum, mode=mode),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def test_call_rebatch(self):
  """Conv folds extra leading axes into the batch dimension."""
  conv_layer = tl.Conv(30, (3, 3))
  # Rank-5 input: two leading "batch-like" axes before (H, W, C).
  inputs = np.ones((2, 9, 5, 5, 20))
  conv_layer.init(shapes.signature(inputs))
  outputs = conv_layer(inputs)
  # Both leading axes survive; spatial dims shrink by VALID 3x3 conv.
  self.assertEqual(outputs.shape, (2, 9, 3, 3, 30))
def LocallyConvDense(n_modules, n_units, kernel_size=1, length_kernel_size=1):
  """Layer using local convolutions for approximation of Dense layer.

  The layer splits the last axis of a tensor into `n_modules`, then runs a
  convolution on all those modules, and concatenates their results. It is
  similar to LocallyConnectedDense above, but shares weights.

  Args:
    n_modules: Indicates how many modules (pixels) should be input and output
        split into for processing.
    n_units: how many outputs (filters) should each module generate.
    kernel_size: The size of the kernel to be used.
    length_kernel_size: If > 1, also do causal convolution on the previous axis,
        which is often the sentence length in sequence models.

  Returns:
    LocallyConvDense base.Layer.
  """
  if n_modules == 1:
    # Degenerate case: one module is just a Dense projection.
    return tl.Dense(n_units)
  if kernel_size % 2 != 1:
    raise ValueError('Currently we only handle odd kernel sizes.')
  # Symmetric padding on the module axis; causal (left-only) padding on the
  # length axis so position t never sees positions > t.
  side_pad = (kernel_size - 1) // 2
  paddings = [[0, 0], [length_kernel_size - 1, 0], [side_pad, side_pad], [0, 0]]
  return tl.Serial(
      tl.SplitLastAxis(n_modules),
      tl.Fn('Pad', lambda x: jnp.pad(x, pad_width=paddings)),
      tl.Conv(n_units, kernel_size=(length_kernel_size, kernel_size)),
      tl.MergeLastTwoAxes(),
  )
def test_use_bias_false(self):
  """Without a bias, Conv weights are a single kernel array."""
  conv_layer = tl.Conv(30, (3, 3), use_bias=False)
  inputs = np.ones((9, 5, 5, 20))
  conv_layer.init(shapes.signature(inputs))
  outputs = conv_layer(inputs)
  self.assertEqual(outputs.shape, (9, 3, 3, 30))
  # With use_bias=False, layer.weights is just 'w' and there is no 'b'.
  self.assertEqual(conv_layer.weights.shape, (3, 3, 20, 30))
def IdentityBlock(kernel_size, filters, norm, non_linearity, mode='train'):
  """ResNet identical size block (bottleneck with a plain-identity shortcut).

  Args:
    kernel_size: int, side of the square kernel for the middle convolution.
    filters: 3-tuple/list of filter counts for the 1x1, KxK, 1x1 convs.
    norm: `Layer` constructor used for normalization.
    non_linearity: `Layer` constructor used as the non-linearity.
    mode: mode passed to every `norm` layer.

  Returns:
    A list of layers implementing the block.
  """
  filters1, filters2, filters3 = filters
  residual_path = [
      tl.Conv(filters1, (1, 1)),
      norm(mode=mode),
      non_linearity(),
      tl.Conv(filters2, (kernel_size, kernel_size), padding='SAME'),
      norm(mode=mode),
      non_linearity(),
      tl.Conv(filters3, (1, 1)),
      norm(mode=mode),
  ]
  return [tl.Residual(residual_path), non_linearity()]
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
  """An Atari CNN."""
  del mode  # Unused; kept for interface consistency with other models.

  # TODO(jonni): Include link to paper?
  # Input shape: (B, T, H, W, C); output shape: (B, T, output_size).
  byte_to_float = tl.Fn(lambda x: x / 255.0)  # Convert unsigned bytes to float.
  return tl.Serial(
      byte_to_float,
      _FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)
      tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Flatten(n_axes_to_keep=2),  # Keep B and T; flatten the rest.
      tl.Dense(output_size),
      tl.Relu(),
  )
def Resnet50(d_hidden=64, n_output_classes=1001, mode='train',
             norm=tl.BatchNorm, non_linearity=tl.Relu):
  """ResNet.

  Args:
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: Number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.
    norm: `Layer` used for normalization, Ex: BatchNorm or
      FilterResponseNorm.
    non_linearity: `Layer` used as a non-linearity, Ex: If norm is
      BatchNorm then this is a Relu, otherwise for FilterResponseNorm this
      should be ThresholdedLinearUnit.

  Returns:
    The list of layers comprising a ResNet model with the given parameters.
  """
  def _filters(multiplier):
    # Bottleneck widths for a group scaled by `multiplier`.
    return [multiplier * dim for dim in [d_hidden, d_hidden, 4 * d_hidden]]

  def _group(multiplier, strides, n_identity):
    # One striding ConvBlock followed by `n_identity` IdentityBlocks,
    # all at the same width.
    blocks = [ConvBlock(3, _filters(multiplier), strides, norm,
                        non_linearity, mode)]
    blocks += [IdentityBlock(3, _filters(multiplier), norm, non_linearity,
                             mode) for _ in range(n_identity)]
    return blocks

  return tl.Serial(
      tl.ToFloat(),
      tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'),
      norm(mode=mode),
      non_linearity(),
      tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
      _group(1, (1, 1), 2),
      _group(2, (2, 2), 3),
      _group(4, (2, 2), 5),
      _group(8, (2, 2), 2),
      tl.AvgPool(pool_size=(7, 7)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def get_model(n_output_classes=10):
  """Simple CNN to classify Fashion MNIST."""
  classifier = tl.Serial(
      tl.ToFloat(),
      # First conv stage: 32 filters, same-padded, then norm + pool.
      tl.Conv(32, (3, 3), (1, 1), "SAME"),
      tl.LayerNorm(),
      tl.Relu(),
      tl.MaxPool(),
      # Second conv stage: 64 filters.
      tl.Conv(64, (3, 3), (1, 1), "SAME"),
      tl.LayerNorm(),
      tl.Relu(),
      tl.MaxPool(),
      # Classification head.
      tl.Flatten(),
      tl.Dense(n_output_classes),
  )
  return classifier
def IdentityBlock(kernel_size, filters, mode='train'):
  """ResNet identical size block (bottleneck; shortcut is the identity)."""
  # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant.
  filters1, filters2, filters3 = filters
  residual_path = [
      tl.Conv(filters1, (1, 1)),
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.Conv(filters2, (kernel_size, kernel_size), padding='SAME'),
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.Conv(filters3, (1, 1)),
      tl.BatchNorm(mode=mode),
  ]
  return [tl.Residual(residual_path), tl.Relu()]
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
  """An Atari CNN."""
  del mode  # Unused; kept for interface consistency with other models.

  # TODO(jonni): Include link to paper?
  # Input shape: (B, T, H, W, C); output shape: (B, T, output_size).
  return tl.Model(
      tl.ToFloat(),
      tl.Div(divisor=255.0),  # Scale pixel bytes into [0, 1].
      # Set up n_frames successive game frames, concatenated on the last axis.
      FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)
      tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
      tl.Relu(),
      tl.Flatten(n_axes_to_keep=2),  # Keep B and T; flatten the rest.
      tl.Dense(output_size),
      tl.Relu(),
  )
def test_use_bias_true(self):
  """With a bias, Conv weights are a (kernel, bias) pair."""
  conv_layer = tl.Conv(30, (3, 3), use_bias=True)
  inputs = np.ones((9, 5, 5, 20))
  conv_layer.init(shapes.signature(inputs))
  outputs = conv_layer(inputs)
  self.assertEqual(outputs.shape, (9, 3, 3, 30))
  # Weights come back as a 2-tuple: kernel 'w' then bias 'b'.
  self.assertIsInstance(conv_layer.weights, tuple)
  self.assertLen(conv_layer.weights, 2)
  self.assertEqual(conv_layer.weights[0].shape, (3, 3, 20, 30))
  self.assertEqual(conv_layer.weights[1].shape, (30, ))
def WideResnetGroup(n, channels, strides=(1, 1), bn_momentum=0.9, mode='train'):
  """Group of `n` WideResnet blocks; only the first block may stride.

  Args:
    n: int, total number of blocks in the group.
    channels: int, channel count for every block in the group.
    strides: strides for the first block and its projection shortcut.
    bn_momentum: float, BatchNorm momentum passed to every block.
    mode: mode passed to every block.

  Returns:
    A list of layers implementing the group.
  """
  # The first block changes shape (stride/width), so its shortcut must be a
  # learned projection rather than the identity.
  projection = [
      tl.Conv(channels, (3, 3), strides, padding='SAME'),
  ]
  first_block = tl.Residual(
      WideResnetBlock(channels, strides, bn_momentum=bn_momentum, mode=mode),
      shortcut=projection)
  remaining_blocks = tl.Residual([
      WideResnetBlock(channels, (1, 1), bn_momentum=bn_momentum, mode=mode)
      for _ in range(n - 1)])
  return [first_block, remaining_blocks]
def Resnet50(d_hidden=64, n_output_classes=1001, mode='train'):
  """ResNet.

  Args:
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: Number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a ResNet model with the given parameters.
  """
  # Bottleneck widths for each group: [m*d, m*d, 4*m*d] for multiplier m.
  filters_x1 = [d_hidden, d_hidden, 4 * d_hidden]
  filters_x2 = [2 * d_hidden, 2 * d_hidden, 8 * d_hidden]
  filters_x4 = [4 * d_hidden, 4 * d_hidden, 16 * d_hidden]
  filters_x8 = [8 * d_hidden, 8 * d_hidden, 32 * d_hidden]
  return tl.Model(
      tl.ToFloat(),
      tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'),
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
      # Group 1: no striding, 2 identity blocks.
      ConvBlock(3, filters_x1, (1, 1), mode=mode),
      *[IdentityBlock(3, filters_x1, mode=mode) for _ in range(2)],
      # Group 2: stride 2, 3 identity blocks.
      ConvBlock(3, filters_x2, (2, 2), mode=mode),
      *[IdentityBlock(3, filters_x2, mode=mode) for _ in range(3)],
      # Group 3: stride 2, 6 identity blocks.
      ConvBlock(3, filters_x4, (2, 2), mode=mode),
      *[IdentityBlock(3, filters_x4, mode=mode) for _ in range(6)],
      # Group 4: stride 2, 2 identity blocks.
      ConvBlock(3, filters_x8, (2, 2), mode=mode),
      *[IdentityBlock(3, filters_x8, mode=mode) for _ in range(2)],
      tl.AvgPool(pool_size=(7, 7)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
def ConvBlock(kernel_size, filters, strides, mode='train'):
  """ResNet convolutional striding block (bottleneck + projection shortcut)."""
  # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant.
  filters1, filters2, filters3 = filters
  residual_path = [
      tl.Conv(filters1, (1, 1), strides),
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.Conv(filters2, (kernel_size, kernel_size), padding='SAME'),
      tl.BatchNorm(mode=mode),
      tl.Relu(),
      tl.Conv(filters3, (1, 1)),
      tl.BatchNorm(mode=mode),
  ]
  # The shortcut must stride and widen to match the residual path's output.
  projection = [
      tl.Conv(filters3, (1, 1), strides),
      tl.BatchNorm(mode=mode),
  ]
  return [tl.Residual(residual_path, shortcut=projection), tl.Relu()]
def ConvBlock(kernel_size, filters, strides, norm, non_linearity, mode='train'):
  """ResNet convolutional striding block.

  Args:
    kernel_size: int, side of the square kernel for the middle convolution.
    filters: 3-tuple/list of filter counts for the 1x1, KxK, 1x1 convs.
    strides: strides applied by the first conv and the shortcut projection.
    norm: `Layer` constructor used for normalization.
    non_linearity: `Layer` constructor used as the non-linearity.
    mode: mode passed to every `norm` layer.

  Returns:
    A list of layers implementing the block.
  """
  filters1, filters2, filters3 = filters
  residual_path = [
      tl.Conv(filters1, (1, 1), strides),
      norm(mode=mode),
      non_linearity(),
      tl.Conv(filters2, (kernel_size, kernel_size), padding='SAME'),
      norm(mode=mode),
      non_linearity(),
      tl.Conv(filters3, (1, 1)),
      norm(mode=mode),
  ]
  # Projection shortcut: matches the residual path's stride and width.
  projection = [
      tl.Conv(filters3, (1, 1), strides),
      norm(mode=mode),
  ]
  return [tl.Residual(residual_path, shortcut=projection), non_linearity()]
def LocallyConvDense(n_modules, n_units, kernel_size=1):
  """Layer using local convolutions for approximation of Dense layer.

  The layer splits the last axis of a tensor into `n_modules`, then runs a
  convolution on all those modules, and concatenates their results. It is
  similar to LocallyConnectedDense above, but shares weights.

  Args:
    n_modules: Indicates how many modules (pixels) should be input and output
        split into for processing.
    n_units: how many outputs (filters) should each module generate.
    kernel_size: The size of the kernel to be used.

  Returns:
    LocallyConvDense base.Layer.
  """
  if n_modules == 1:
    # Degenerate case: a single module is exactly a Dense projection.
    return tl.Dense(n_units)
  return tl.Serial(
      tl.SplitLastAxis(n_modules),
      tl.Conv(n_units, kernel_size=(1, kernel_size), padding='SAME'),
      tl.MergeLastTwoAxes(),
  )
def BuildConv():
  """Builds a same-padded 2-D convolution layer.

  NOTE(review): `units` and `kernel_size` are free variables captured from an
  enclosing scope that is not visible in this chunk — presumably set by the
  surrounding test or builder function; confirm against the full file.
  """
  return tl.Conv(filters=units, kernel_size=kernel_size, padding='SAME')