示例#1
0
def AvgPool(window_shape, strides=None, padding=Padding.VALID.value):
    """Layer construction function for a 2D average pooling layer.

    Based on `jax.experimental.stax.AvgPool`. Has a similar API apart from:

    Args:
      window_shape: shape of the pooling window. It is also used as the filter
        shape when computing the circular padding.
      strides: pooling strides. If falsy (e.g. the default `None`), a stride
        of 1 is used along every dimension of `window_shape`.
      padding: in addition to `VALID` and `SAME` padding, supports `CIRCULAR`,
        not available in `jax.experimental.stax.AvgPool`.

    Returns:
      A `(init_fun, apply_fun, ker_fun)` triple: the `stax` init and apply
      functions, and a kernel function that pools the `nngp` and `ntk`
      entries of an input `Kernel`.
    """
    strides = strides or (1, ) * len(window_shape)
    padding = Padding(padding)

    if padding == Padding.CIRCULAR:
        # Shape inference (`init_fun`) uses SAME padding, while the actual
        # pooling is VALID pooling over a wrap-padded input (see below).
        init_fun, _ = stax.AvgPool(window_shape, strides, Padding.SAME.value)
        _, apply_fun_0 = stax.AvgPool(window_shape, strides,
                                      Padding.VALID.value)

        def apply_fun(params, inputs, **kwargs):
            # Wrap-pad axes (1, 2) (presumably the spatial axes of an NHWC
            # input -- confirm), by the amount the helper's name suggests
            # SAME padding would add, then pool without further padding.
            inputs = _same_pad_for_filter_shape(inputs, window_shape, strides,
                                                (1, 2), 'wrap')
            return apply_fun_0(params, inputs, **kwargs)
    else:
        init_fun, apply_fun = stax.AvgPool(window_shape, strides,
                                           padding.value)

    def ker_fun(kernels):
        """Kernel transformation: average-pool the NNGP and NTK."""
        var1, nngp, var2, ntk, is_gaussian, is_height_width = kernels

        # When the kernel's spatial axes are not in height-width order (per
        # `is_height_width`), reverse the window and strides to match them.
        if not is_height_width:
            window_shape_nngp = window_shape[::-1]
            strides_nngp = strides[::-1]
        else:
            window_shape_nngp = window_shape
            strides_nngp = strides

        # The NTK is pooled with exactly the same operation as the NNGP.
        nngp = _average_pool_nngp_6d(nngp, window_shape_nngp, strides_nngp,
                                     padding)
        ntk = _average_pool_nngp_6d(ntk, window_shape_nngp, strides_nngp,
                                    padding)

        if var2 is None:
            # Single set of inputs: recompute `var1` from the diagonal of the
            # pooled NNGP.
            var1 = _diagonal_nngp_6d(nngp)
        else:
            # TODO(romann)
            # `var1`/`var2` are left un-pooled here, hence the warning.
            warnings.warn(
                'Pooling for different inputs `x1` and `x2` is not '
                'implemented and will only work if there are no '
                'nonlinearities in the network anywhere after the pooling '
                'layer. `var1` and `var2` will have wrong values. '
                'This will be fixed soon.')

        return Kernel(var1, nngp, var2, ntk, is_gaussian, is_height_width)

    # Tag `ker_fun` with the `_USE_POOLING` marker attribute (presumably read
    # elsewhere to detect pooling layers -- confirm against callers).
    setattr(ker_fun, _USE_POOLING, True)
    return init_fun, apply_fun, ker_fun
示例#2
0
def Resnet50(hidden_size=64, num_output_classes=1001):
    """Build a ResNet-50 classifier out of `stax` combinators.

    The network consists of a 7x7 stride-2 convolutional stem with max
    pooling, then four stages of 3, 4, 6 and 3 blocks. Each stage is one
    `ConvBlock` followed by `IdentityBlock`s. The network ends with an
    average-pool / flatten / dense / log-softmax head.

    Args:
      hidden_size: the size of the first hidden layer; the four stages use
        1x, 2x, 4x and 8x this width (with 4x that width as the ConvBlock
        output size).
      num_output_classes: how many classes to distinguish.

    Returns:
      The model returned by `stax.serial`.
    """
    # Per stage: (width multiplier, number of IdentityBlocks, ConvBlock strides).
    stage_table = (
        (1, 2, (1, 1)),
        (2, 3, (2, 2)),
        (4, 5, (2, 2)),
        (8, 2, (2, 2)),
    )

    network = [
        stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
        stax.BatchNorm(),
        stax.Relu,
        stax.MaxPool((3, 3), strides=(2, 2)),
    ]
    for multiplier, identity_count, conv_strides in stage_table:
        width = multiplier * hidden_size
        network.append(ConvBlock(3, [width, width, 4 * width], conv_strides))
        network.extend(
            IdentityBlock(3, [width, width]) for _ in range(identity_count))

    network.extend([
        stax.AvgPool((7, 7)),
        stax.Flatten,
        stax.Dense(num_output_classes),
        stax.LogSoftmax,
    ])
    return stax.serial(*network)
示例#3
0
    def __init__(self, num_classes=100, encoding=True):
        """Assemble a ResNet-50 style layer stack and store it in ``self.model``.

        Args:
          num_classes: output size of the final ``Dense`` layer; only used
            when ``encoding`` is False.
          encoding: if True, the stack ends at the final 7x7 average pool.
            If False, a ``Flatten`` and ``Dense(num_classes)`` head is
            appended.
        """
        layers = [
            stax.GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2),
                             'SAME'),
            stax.BatchNorm(),
            stax.Relu,
            stax.MaxPool((3, 3), strides=(2, 2)),
        ]

        # Per stage: (width, number of IdentityBlocks, extra ConvBlock kwargs).
        # Only the first stage overrides the ConvBlock strides.
        stage_plan = (
            (64, 2, {'strides': (1, 1)}),
            (128, 3, {}),
            (256, 5, {}),
            (512, 2, {}),
        )
        for width, repeats, conv_kwargs in stage_plan:
            layers.append(
                self.ConvBlock(3, [width, width, 4 * width], **conv_kwargs))
            layers.extend(
                self.IdentityBlock(3, [width, width]) for _ in range(repeats))
        layers.append(stax.AvgPool((7, 7)))

        if not encoding:
            layers += [stax.Flatten, stax.Dense(num_classes)]

        self.model = stax.serial(*layers)
示例#4
0
def wide_resnet(n, k, num_classes):
    """Build the original Wide ResNet used in the paper and earlier experiments.

    Args:
      n: block count passed to each of the three `wide_resnet_group` calls.
      k: widening factor applied to the base group widths 16, 32 and 64.
      num_classes: output size of the final dense layer.

    Returns:
      The model returned by `stax.serial`. The final `Dense` outputs are
      returned as-is; no softmax layer is appended.
    """
    stem = stax.Conv(16, (3, 3), padding='SAME')

    # (base width, strides) per group; only the first group uses unit strides.
    group_specs = ((16, (1, 1)), (32, (2, 2)), (64, (2, 2)))
    groups = [
        wide_resnet_group(n, base * k, strides=group_strides)
        for base, group_strides in group_specs
    ]

    head = (
        stax.BatchNorm(),
        stax.Relu,
        stax.AvgPool((8, 8)),
        stax.Flatten,
        stax.Dense(num_classes),
    )
    return stax.serial(stem, *groups, *head)
示例#5
0
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10):
    """Build a Wide ResNet (https://arxiv.org/pdf/1605.07146.pdf) classifier.

    Args:
      num_blocks: int, number of blocks in each of the three groups.
      hidden_size: width of the first hidden layer; the second and third
        groups use 2x and 4x this width.
      num_output_classes: int, number of classes to distinguish.

    Returns:
      The model returned by `stax.serial`, ending in a log-softmax layer.
    """
    model_layers = [
        stax.Conv(hidden_size, (3, 3), padding='SAME'),
        WideResnetGroup(num_blocks, hidden_size),
    ]
    # The wider groups get an extra (2, 2) argument (presumably strides).
    model_layers += [
        WideResnetGroup(num_blocks, hidden_size * factor, (2, 2))
        for factor in (2, 4)
    ]
    model_layers += [
        stax.BatchNorm(),
        stax.Relu,
        stax.AvgPool((8, 8)),
        stax.Flatten,
        stax.Dense(num_output_classes),
        stax.LogSoftmax,
    ]
    return stax.serial(*model_layers)