Example #1
def MakeMain(input_shape):
    # filters1, filters2 and ks come from the enclosing scope; the last conv
    # emits input_shape[3] channels, so the number of output channels depends
    # on the number of input channels (as needed for the residual sum)
    return stax.serial(stax.Conv(filters1, (1, 1)), stax.BatchNorm(),
                       stax.Relu,
                       stax.Conv(filters2, (ks, ks), padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())
Example #2
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutational block."""
    main = stax.serial(stax.BatchNorm(), stax.Relu,
                       stax.Conv(channels, (3, 3), strides, padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(channels, (3, 3), padding='SAME'))
    if channel_mismatch:
        shortcut = stax.Conv(channels, (3, 3), strides, padding='SAME')
    else:
        shortcut = stax.Identity
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum)
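
The FanOut(2) → parallel(main, shortcut) → FanInSum chain above is the stax idiom for a residual connection: FanOut duplicates the input, parallel runs one copy through each branch, and FanInSum adds the branch outputs elementwise. A minimal sketch of the same pattern as a reusable wrapper (the Residual name is hypothetical; `block` must preserve its input shape for the sum to be valid):

def Residual(block):
    # duplicate the input, run `block` and the identity in parallel,
    # then add the two branch outputs elementwise
    return stax.serial(stax.FanOut(2), stax.parallel(block, stax.Identity),
                       stax.FanInSum)
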
Example #3
    def ConvBlock(self, kernel_size, filters, strides=(2, 2)):
        filters1, filters2, filters3 = filters
        Main = stax.serial(
            stax.Conv(filters1, (1, 1), strides), stax.BatchNorm(), stax.Relu,
            stax.Conv(filters2, (kernel_size, kernel_size), padding='SAME'),
            stax.BatchNorm(), stax.Relu, stax.Conv(filters3, (1, 1)),
            stax.BatchNorm())
        Shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                               stax.BatchNorm())
        return stax.serial(stax.FanOut(2), stax.parallel(Main, Shortcut),
                           stax.FanInSum, stax.Relu)
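
Note that in this ConvBlock the main path changes both the spatial size (via strides) and the channel count, so the shortcut cannot be the identity: a strided 1x1 convolution plus BatchNorm projects the input to the matching shape before FanInSum.
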
Example #4
def create_double_conv(d: int, out_channels: int, mid_channels: int,
                       batch_norm: bool, activation: Callable):
    # normalize over the batch and all d spatial axes, leaving the channel axis
    norm = (stax.BatchNorm(axis=tuple(range(d + 1)))
            if batch_norm else stax.Identity)
    return stax.serial(
        CONV[d](mid_channels, (3,) * d, padding='same'),
        norm,
        activation,
        CONV[d](out_channels, (3,) * d, padding='same'),
        norm,
        activation,
    )
Example #5
def ConvBlock(kernel_size, filters, strides):
    """ResNet convolutional striding block."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = stax.serial(stax.Conv(filters1, (1, 1), strides),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(filters2, (ks, ks), padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(filters3, (1, 1)), stax.BatchNorm())
    shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                           stax.BatchNorm())
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum, stax.Relu)
Example #6
def wide_resnet_block(num_channels, strides=(1, 1), channel_mismatch=False):
    """Wide ResNet block."""
    # `pre` is a stateless (init_fn, apply_fn) pair, so its two uses below
    # are initialized with independent parameters rather than shared ones
    pre = stax.serial(stax.BatchNorm(), stax.Relu)
    mid = stax.serial(
        pre, stax.Conv(num_channels, (3, 3), strides, padding='SAME'),
        stax.BatchNorm(), stax.Relu,
        stax.Conv(num_channels, (3, 3), strides=(1, 1), padding='SAME'))
    if channel_mismatch:
        cut = stax.serial(
            pre, stax.Conv(num_channels, (3, 3), strides, padding='SAME'))
    else:
        cut = stax.Identity
    return stax.serial(stax.FanOut(2), stax.parallel(mid, cut), stax.FanInSum)
Example #7
def Resnet50(hidden_size=64, num_output_classes=1001):
    """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
    return stax.serial(
        stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
        stax.BatchNorm(), stax.Relu,
        stax.MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
        IdentityBlock(3, [hidden_size, hidden_size]),
        IdentityBlock(3, [hidden_size, hidden_size]),
        ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
        stax.AvgPool((7, 7)), stax.Flatten,
        stax.Dense(num_output_classes), stax.LogSoftmax)
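
The block counts above (one ConvBlock plus 2, 3, 5 and 2 IdentityBlocks per stage) give the standard 3-4-6-3 stage layout of ResNet-50.
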
Example #8
def _batch_norm_internal(batchnorm, axis=(0, 1, 2)):
    """Layer constructor for a stax.BatchNorm layer with dummy kernel computation.

    Do not use kernels for architectures that include this function."""
    init_fn, apply_fn = stax.BatchNorm(axis=axis)
    kernel_fn = lambda kernels: kernels  # pass kernels through unchanged
    return init_fn, apply_fn, kernel_fn
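
The (init_fn, apply_fn, kernel_fn) triple mirrors the layer convention of kernel-based libraries such as neural-tangents; the pass-through kernel_fn presumably exists only so the layer composes with kernel-aware combinators, which is why the docstring warns against actually using kernels with architectures that include it.
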
Example #9
    def __init__(self, num_classes=100, encoding=True):

        blocks = [
            stax.GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2),
                             'SAME'),
            stax.BatchNorm(), stax.Relu,
            stax.MaxPool((3, 3), strides=(2, 2)),
            self.ConvBlock(3, [64, 64, 256], strides=(1, 1)),
            self.IdentityBlock(3, [64, 64]),
            self.IdentityBlock(3, [64, 64]),
            self.ConvBlock(3, [128, 128, 512]),
            self.IdentityBlock(3, [128, 128]),
            self.IdentityBlock(3, [128, 128]),
            self.IdentityBlock(3, [128, 128]),
            self.ConvBlock(3, [256, 256, 1024]),
            self.IdentityBlock(3, [256, 256]),
            self.IdentityBlock(3, [256, 256]),
            self.IdentityBlock(3, [256, 256]),
            self.IdentityBlock(3, [256, 256]),
            self.IdentityBlock(3, [256, 256]),
            self.ConvBlock(3, [512, 512, 2048]),
            self.IdentityBlock(3, [512, 512]),
            self.IdentityBlock(3, [512, 512]),
            stax.AvgPool((7, 7))
        ]

        if not encoding:
            blocks.append(stax.Flatten)
            blocks.append(stax.Dense(num_classes))

        self.model = stax.serial(*blocks)
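
stax.GeneralConv takes an (input, kernel, output) dimension-numbers triple, so the first layer here consumes 'HWCN'-ordered inputs and produces 'NHWC' outputs, which the remaining layers (following stax's default NHWC convention) then consume.
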
Example #10
def wide_resnet(n, k, num_classes):
    """Original WRN from paper and previous experiments."""
    return stax.serial(stax.Conv(16, (3, 3), padding='SAME'),
                       wide_resnet_group(n, 16 * k, strides=(1, 1)),
                       wide_resnet_group(n, 32 * k, strides=(2, 2)),
                       wide_resnet_group(n, 64 * k, strides=(2, 2)),
                       stax.BatchNorm(), stax.Relu, stax.AvgPool((8, 8)),
                       stax.Flatten, stax.Dense(num_classes))
Example #11
def create_model(nbin, nhidden, nlayer):
    layers = []
    for _ in range(nlayer):
        layers.extend([
            stax.Dense(nhidden),
            stax.LeakyRelu,
            stax.BatchNorm(axis=(0, 1)),
        ])
    layers.extend([stax.Dense(nbin), stax.Softmax])
    return stax.serial(*layers)
Example #12
    def testBatchNormShapeNHWC(self):
        # stax init_fns take an explicit PRNG key in current JAX versions
        key = random.PRNGKey(0)
        init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
        input_shape = (4, 5, 6, 7)
        inputs = random_inputs(onp.random.RandomState(0), input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        self.assertEqual(out_shape, input_shape)
        beta, gamma = params
        self.assertEqual(beta.shape, (7,))
        self.assertEqual(gamma.shape, (7,))
        self.assertEqual(out_shape, out.shape)
Example #13
  def testBatchNormNoScaleOrCenter(self):
    key = random.PRNGKey(0)
    axes = (0, 1, 2)
    init_fun, apply_fun = stax.BatchNorm(axis=axes, center=False, scale=False)
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(onp.random.RandomState(0), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)
    means = onp.mean(out, axis=(0, 1, 2))
    std_devs = onp.std(out, axis=(0, 1, 2))
    assert onp.allclose(means, onp.zeros_like(means), atol=1e-4)
    assert onp.allclose(std_devs, onp.ones_like(std_devs), atol=1e-4)
Example #14
    def testBatchNormShapeNCHW(self):
        # Regression test for https://github.com/google/jax/issues/461
        # stax init_fns take an explicit PRNG key in current JAX versions
        key = random.PRNGKey(0)
        init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
        input_shape = (4, 5, 6, 7)
        inputs = random_inputs(onp.random.RandomState(0), input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        self.assertEqual(out_shape, input_shape)
        beta, gamma = params
        self.assertEqual(beta.shape, (5,))
        self.assertEqual(gamma.shape, (5,))
        self.assertEqual(out_shape, out.shape)
Example #15
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
    return stax.serial(stax.Conv(hidden_size, (3, 3), padding='SAME'),
                       WideResnetGroup(num_blocks, hidden_size),
                       WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
                       WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
                       stax.BatchNorm(), stax.Relu,
                       stax.AvgPool((8, 8)), stax.Flatten,
                       stax.Dense(num_output_classes), stax.LogSoftmax)
Example #16
def dense_net(in_channels: int,
              out_channels: int,
              layers: Union[tuple, list],
              batch_norm=False,
              activation='ReLU') -> StaxNet:
    activation = {
        'ReLU': stax.Relu,
        'Sigmoid': stax.Sigmoid,
        'tanh': stax.Tanh
    }[activation]
    stax_layers = []
    for neuron_count in layers:
        stax_layers.append(stax.Dense(neuron_count))
        stax_layers.append(activation)
        if batch_norm:
            stax_layers.append(stax.BatchNorm(axis=(0, )))
    stax_layers.append(stax.Dense(out_channels))
    net_init, net_apply = stax.serial(*stax_layers)
    net = StaxNet(net_init, net_apply, (-1, in_channels))
    net.initialize()
    return net
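
For reference, a minimal end-to-end sketch of constructing, initializing and applying a stax model that contains BatchNorm (this assumes the newer import path jax.example_libraries.stax; older JAX versions exposed the same module as jax.experimental.stax):

import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# a two-layer MLP with batch normalization over the batch axis
init_fn, apply_fn = stax.serial(
    stax.Dense(32),
    stax.BatchNorm(axis=(0,)),  # 2-D inputs: normalize over axis 0 only
    stax.Relu,
    stax.Dense(10),
)

rng = jax.random.PRNGKey(0)
out_shape, params = init_fn(rng, (-1, 64))  # -1 marks the batch dimension
x = jnp.ones((8, 64))
y = apply_fn(params, x)  # y.shape == (8, 10)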