def IgnoreConv2D(out_dim,
                 W_init=he_normal(),
                 b_init=normal(),
                 kernel=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        out = apply_fun_wrapped(params, x, **kwargs)
        return (out, t)

    return init_fun_wrapped, apply_fun
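A minimal usage sketch, assuming the module-level `dimension_numbers` is the usual ('NHWC', 'HWIO', 'NHWC') spec, `padding` is passed as a stax-style string, and `np` is aliased to `jax.numpy` as in the examples below. Inputs travel as (x, t) pairs; the layer convolves x and passes t through unchanged:

# Usage sketch (assumptions noted above; shapes are illustrative only).
from jax import random
import jax.numpy as np

init_fun, apply_fun = IgnoreConv2D(out_dim=32, kernel=3, stride=1, padding='SAME')
out_shape, params = init_fun(random.PRNGKey(0), (4, 28, 28, 1))
x = np.zeros((4, 28, 28, 1))
y, t_out = apply_fun(params, (x, np.array(0.1)))  # y: (4, 28, 28, 32); t is returned untouched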
Example 2
def GeneralConv(dimension_numbers,
                out_chan,
                filter_shape,
                strides=None,
                padding='VALID',
                W_gain=1.0,
                W_init=stax.randn(1.0),
                b_gain=0.0,
                b_init=stax.randn(1.0)):
    """Layer construction function for a general convolution layer.

  Uses jax.experimental.stax.GeneralConv as a base.
  """
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1, ) * len(filter_shape)
    strides = strides or one
    init_fun, _ = stax.GeneralConv(dimension_numbers, out_chan, filter_shape,
                                   strides, padding, W_init, b_init)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        norm = inputs.shape[lhs_spec.index('C')]
        norm *= functools.reduce(op.mul, filter_shape)
        norm = W_gain / np.sqrt(norm)
        return norm * lax.conv_general_dilated(inputs, W, strides, padding,
                                               one, one,
                                               dimension_numbers) + b_gain * b

    return init_fun, apply_fun
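A brief usage sketch of this parameterization, assuming `np` is `jax.numpy`, `op` is the `operator` module, and `functools` is imported at module level; the weights stay at unit scale and the `W_gain / sqrt(fan_in)` factor is applied at call time rather than baked into the initializer:

from jax import random

init_fun, apply_fun = GeneralConv(('NHWC', 'HWIO', 'NHWC'), out_chan=16,
                                  filter_shape=(3, 3), padding='SAME')
out_shape, params = init_fun(random.PRNGKey(0), (8, 32, 32, 3))
x = random.normal(random.PRNGKey(1), (8, 32, 32, 3))
y = apply_fun(params, x)  # output scaled by W_gain / sqrt(C_in * 3 * 3) at apply time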
Example 3
    def __init__(self, num_classes=100, encoding=True):

        blocks = [
            stax.GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2),
                             'SAME'),
            stax.BatchNorm(), stax.Relu,
            stax.MaxPool((3, 3), strides=(2, 2)),
            self.ConvBlock(3, [64, 64, 256], strides=(1, 1)),
            self.IdentityBlock(3, [64, 64]),
            self.IdentityBlock(3, [64, 64]),
            self.ConvBlock(3, [128, 128, 512]),
            self.IdentityBlock(3, [128, 128]),
            self.IdentityBlock(3, [128, 128]),
            self.IdentityBlock(3, [128, 128]),
            self.ConvBlock(3, [256, 256, 1024]),
            self.IdentityBlock(3, [256, 256]),
            self.IdentityBlock(3, [256, 256]),
            self.IdentityBlock(3, [256, 256]),
            self.IdentityBlock(3, [256, 256]),
            self.IdentityBlock(3, [256, 256]),
            self.ConvBlock(3, [512, 512, 2048]),
            self.IdentityBlock(3, [512, 512]),
            self.IdentityBlock(3, [512, 512]),
            stax.AvgPool((7, 7))
        ]

        if not encoding:
            blocks.append(stax.Flatten)
            blocks.append(stax.Dense(num_classes))

        self.model = stax.serial(*blocks)
Example 4
def ResidualBlock(out_channels, kernel_size, stride, padding, input_format):
    double_conv = stax.serial(
        stax.GeneralConv(input_format, out_channels, kernel_size, stride, padding),
        stax.Elu,
    )
    return Module(
        *stax.serial(
            stax.FanOut(2), stax.parallel(double_conv, stax.Identity), stax.FanInSum
        )
    )
Example 5
def ResNet(hidden_channels, out_channels, depth):
    # time integration module
    backbone = stax.serial(
        stax.GeneralConv(
            ("NCDWH", "IDWHO", "NCDWH"), hidden_channels, (4, 3, 3), (1, 1, 1), "SAME"
        ),
        *[
            ResidualBlock(
                hidden_channels,
                (4, 3, 3),
                (1, 1, 1),
                "SAME",
                ("NCDWH", "IDWHO", "NCDWH"),
            )
            for _ in range(depth)
        ],
        stax.GeneralConv(
            ("NCDWH", "IDWHO", "NCDWH"), out_channels, (4, 3, 3), (1, 1, 1), "SAME"
        ),
        stax.GeneralConv(("NDCWH", "IDWHO", "NDCWH"), 3, (3, 3, 3), (1, 1, 1), "SAME"),
    )

    #  euler scheme
    return stax.serial(stax.FanOut(2), stax.parallel(stax.Identity, backbone), Euler())
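`Euler()` is not defined in this snippet. A plausible reading of the "euler scheme" comment is a fan-in layer that adds the backbone output onto the identity branch, scaled by a step size. The sketch below is a hypothetical stand-in; the name of the `dt` argument and the exact update rule are assumptions, not the source's definition:

def Euler(dt=1.0):
    # Hypothetical fan-in layer implementing x_{k+1} = x_k + dt * f(x_k),
    # consuming the (identity, backbone) pair produced by FanOut/parallel above.
    def init_fun(rng, input_shape):
        return input_shape[0], ()  # both branches are assumed to share a shape
    def apply_fun(params, inputs, **kwargs):
        x, fx = inputs
        return x + dt * fx
    return init_fun, apply_fun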
def ConcatSquashConv2D(out_dim,
                       W_init=he_normal(),
                       b_init=normal(),
                       kernel=3,
                       stride=1,
                       padding=0,
                       dilation=1,
                       groups=1,
                       bias=True,
                       transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2, k3, k4 = random.split(rng, 4)
        output_shape_conv, params_conv = init_fun_wrapped(k1, input_shape)
        W_hyper_gate, b_hyper_gate = W_init(k2, (1, out_dim)), b_init(
            k3, (out_dim, ))
        W_hyper_bias = W_init(k4, (1, out_dim))
        return output_shape_conv, (params_conv, W_hyper_gate, b_hyper_gate,
                                   W_hyper_bias)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_conv, W_hyper_gate, b_hyper_gate, W_hyper_bias = params
        conv_out = apply_fun_wrapped(params_conv, x, **kwargs)
        # map the scalar time t to a per-channel gate and bias, broadcast over NHWC
        gate_out = jax.nn.sigmoid(
            np.dot(np.reshape(t, (1, 1)), W_hyper_gate) + b_hyper_gate).reshape(
                1, 1, 1, -1)
        bias_out = np.dot(np.reshape(t, (1, 1)), W_hyper_bias).reshape(1, 1, 1, -1)
        out = conv_out * gate_out + bias_out
        return (out, t)

    return init_fun, apply_fun
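As above, a usage sketch under the same assumptions (module-level `dimension_numbers = ('NHWC', 'HWIO', 'NHWC')`, string padding, `np` as `jax.numpy`); the scalar time t is mapped through the hypernetwork weights to a per-channel gate and bias on the convolution output:

from jax import random
import jax.numpy as np

init_fun, apply_fun = ConcatSquashConv2D(out_dim=64, kernel=3, stride=1, padding='SAME')
out_shape, params = init_fun(random.PRNGKey(0), (2, 16, 16, 8))
x = random.normal(random.PRNGKey(1), (2, 16, 16, 8))
y, _ = apply_fun(params, (x, np.array(0.3)))  # sigmoid gate and additive bias both depend on t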
def ConcatCoordConv2D(out_dim,
                      W_init=he_normal(),
                      b_init=normal(),
                      kernel=3,
                      stride=1,
                      padding=0,
                      dilation=1,
                      groups=1,
                      bias=True,
                      transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        concat_input_shape = list(input_shape)
        # add the time and two coordinate channels (channel axis is last in NHWC)
        concat_input_shape[-1] += 3
        concat_input_shape = tuple(concat_input_shape)
        return init_fun_wrapped(rng, concat_input_shape)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        b, h, w, c = x.shape
        # per-pixel row, column and time channels, each with shape (b, h, w, 1)
        hh = np.broadcast_to(np.arange(h, dtype=x.dtype).reshape(1, h, 1, 1), (b, h, w, 1))
        ww = np.broadcast_to(np.arange(w, dtype=x.dtype).reshape(1, 1, w, 1), (b, h, w, 1))
        tt = np.ones((b, h, w, 1), dtype=x.dtype) * t
        x_aug = np.concatenate([x, hh, ww, tt], axis=-1)
        out = apply_fun_wrapped(params, x_aug, **kwargs)
        return (out, t)

    return init_fun, apply_fun
def ConcatConv2D_v2(out_dim,
                    W_init=he_normal(),
                    b_init=normal(),
                    kernel=3,
                    stride=1,
                    padding=0,
                    dilation=1,
                    groups=1,
                    bias=True,
                    transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        output_shape_conv, params_conv = init_fun_wrapped(k1, input_shape)
        W_hyper_bias = W_init(k2, (1, out_dim))

        return output_shape_conv, (params_conv, W_hyper_bias)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_conv, W_hyper_bias = params
        # conv output plus a time-dependent bias, broadcast over the NHWC output
        out = apply_fun_wrapped(params_conv, x, **kwargs) + np.dot(
            np.reshape(t, (1, 1)), W_hyper_bias).reshape(
                1, 1, 1, -1)  # if NCHW instead of NHWC: .reshape(1, -1, 1, 1)
        return (out, t)

    return init_fun, apply_fun
def BlendConv2D(out_dim,
                W_init=he_normal(),
                b_init=normal(),
                kernel=3,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=True,
                transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        output_shape, params_f = init_fun_wrapped(k1, input_shape)
        _, params_g = init_fun_wrapped(k2, input_shape)
        return output_shape, (params_f, params_g)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_f, params_g = params
        f = apply_fun_wrapped(params_f, x)
        g = apply_fun_wrapped(params_g, x)
        out = f + (g - f) * t
        return (out, t)

    return init_fun, apply_fun
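BlendConv2D linearly interpolates between two convolutions of the same shape, so at t = 0 the output is the f branch and at t = 1 it is the g branch. A quick check, under the same `dimension_numbers` assumption as the sketches above:

from jax import random

init_fun, apply_fun = BlendConv2D(out_dim=32, kernel=3, stride=1, padding='SAME')
_, params = init_fun(random.PRNGKey(0), (2, 16, 16, 8))
x = random.normal(random.PRNGKey(1), (2, 16, 16, 8))
y0, _ = apply_fun(params, (x, 0.0))  # equals the f-branch convolution
y1, _ = apply_fun(params, (x, 1.0))  # equals the g-branch convolution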
Example 10
def ConcatConv2D(out_dim,
                 W_init=he_normal(),
                 b_init=normal(),
                 kernel=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):  # input_shape describes x only; t has no shape entry
        concat_input_shape = list(input_shape)
        concat_input_shape[-1] += 1  # add time channel dim
        concat_input_shape = tuple(concat_input_shape)
        return init_fun_wrapped(rng, concat_input_shape)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        tt = np.ones_like(x[:, :, :, :1]) * t
        xtt = np.concatenate([x, tt], axis=-1)
        out = apply_fun_wrapped(params, xtt, **kwargs)
        return (out, t)

    return init_fun, apply_fun
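Because every layer in this family maps an (x, t) pair to an (out, t) pair, they compose directly with `stax.serial`; plain elementwise layers such as `stax.Relu` would need a similar tuple-aware wrapper. A sketch, under the same `dimension_numbers` assumption:

from jax import random
import jax.numpy as np
from jax.experimental import stax

init_fun, apply_fun = stax.serial(
    ConcatConv2D(64, kernel=3, stride=1, padding='SAME'),
    ConcatConv2D(1, kernel=3, stride=1, padding='SAME'),
)
_, params = init_fun(random.PRNGKey(0), (2, 28, 28, 1))
x = random.normal(random.PRNGKey(1), (2, 28, 28, 1))
dxdt, _ = apply_fun(params, (x, np.array(0.5)))  # time-conditioned dynamics for an ODE solver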
Example 11
def _GeneralConv(dimension_numbers,
                 out_chan,
                 filter_shape,
                 strides=None,
                 padding=Padding.VALID.value,
                 W_std=1.0,
                 W_init=_randn(1.0),
                 b_std=0.0,
                 b_init=_randn(1.0)):
    """Layer construction function for a general convolution layer.

  Based on `jax.experimental.stax.GeneralConv`. Has a similar API apart from:

  Args:
    padding: in addition to `VALID` and `SAME` padding, supports `CIRCULAR`,
      which is not available in `jax.experimental.stax.GeneralConv`.
  """
    if dimension_numbers != _CONV_DIMENSION_NUMBERS:
        raise NotImplementedError('Dimension numbers %s not implemented.' %
                                  str(dimension_numbers))

    lhs_spec, rhs_spec, out_spec = dimension_numbers

    one = (1, ) * len(filter_shape)
    strides = strides or one

    padding = Padding(padding)
    init_padding = padding
    if padding == Padding.CIRCULAR:
        init_padding = Padding.SAME

    init_fun, _ = stax.GeneralConv(dimension_numbers, out_chan, filter_shape,
                                   strides, init_padding.value, W_init, b_init)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        norm = inputs.shape[lhs_spec.index('C')]
        norm *= np.prod(filter_shape)
        apply_padding = padding
        if padding == Padding.CIRCULAR:
            apply_padding = Padding.VALID
            inputs = _same_pad_for_filter_shape(inputs, filter_shape, strides,
                                                (1, 2), 'wrap')
        norm = W_std / np.sqrt(norm)

        return norm * lax.conv_general_dilated(
            inputs,
            W,
            strides,
            apply_padding.value,
            dimension_numbers=dimension_numbers) + b_std * b

    def ker_fun(kernels):
        """Compute the transformed kernels after a conv layer."""
        # var1: batch_1 * height * width
        # var2: batch_2 * height * width
        # nngp, ntk: batch_1 * batch_2 * height * height * width * width (pooling)
        #  or batch_1 * batch_2 * height * width (flattening)
        var1, nngp, var2, ntk, _, is_height_width = kernels

        if nngp.ndim == 4:

            def conv_var(x):
                x = _conv_var_3d(x, filter_shape, strides, padding)
                x = _affine(x, W_std, b_std)
                return x

            def conv_nngp(x):
                if _is_array(x):
                    x = _conv_nngp_4d(x, filter_shape, strides, padding)
                x = _affine(x, W_std, b_std)
                return x

        elif nngp.ndim == 6:
            if not is_height_width:
                filter_shape_nngp = filter_shape[::-1]
                strides_nngp = strides[::-1]
            else:
                filter_shape_nngp = filter_shape
                strides_nngp = strides

            def conv_var(x):
                x = _conv_var_3d(x, filter_shape_nngp, strides_nngp, padding)
                if x is not None:
                    x = np.transpose(x, (0, 2, 1))
                x = _affine(x, W_std, b_std)
                return x

            def conv_nngp(x):
                if _is_array(x):
                    x = _conv_nngp_6d_double_conv(x, filter_shape_nngp,
                                                  strides_nngp, padding)
                x = _affine(x, W_std, b_std)
                return x

            is_height_width = not is_height_width

        else:
            raise ValueError('`nngp` array must be either 4d or 6d, got %d.' %
                             nngp.ndim)

        var1 = conv_var(var1)
        var2 = conv_var(var2)
        nngp = conv_nngp(nngp)
        ntk = conv_nngp(ntk) + nngp - b_std**2 if ntk is not None else ntk
        return Kernel(var1, nngp, var2, ntk, True, is_height_width)

    return init_fun, apply_fun, ker_fun
Example 12
def TaylorConv(out_chan,
               filter_shape,
               strides=None,
               padding=Padding.VALID.name,
               W_std=1.0,
               W_init=_randn(1.0),
               b_std=0.0,
               b_init=_randn(1.0),
               order=2):
    """Layer construction function for a convolution layer with Taylorized parameterization.
    Based on `jax.experimental.stax.GeneralConv`. Has a similar API apart from:
    Args:
        padding: in addition to `VALID` and `SAME` padding, supports `CIRCULAR`,
        which is not available in `jax.experimental.stax.GeneralConv`.
    """
    assert(isinstance(order, int) and order >= 1)
    dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
    lhs_spec, rhs_spec, out_spec = dimension_numbers

    one = (1,) * len(filter_shape)
    strides = strides or one

    padding = Padding(padding)
    init_padding = padding
    if padding == Padding.CIRCULAR:
        init_padding = Padding.SAME

    def input_total_dim(input_shape):
        return input_shape[lhs_spec.index('C')] * np.prod(filter_shape)

    ntk_init_fn, _ = jax_stax.GeneralConv(dimension_numbers, out_chan, filter_shape,
                                          strides, init_padding.name, W_init, b_init)

    def taylor_init_fn(rng, input_shape):
        output_shape, (W, b) = ntk_init_fn(rng, input_shape)
        norm = W_std / (input_total_dim(input_shape) ** ((order-1)/(2*order+2)))
        return output_shape, (W * norm, b * b_std)

    def apply_fn(params, inputs, **kwargs):
        W, b = params
        norm = W_std / (input_total_dim(inputs.shape) ** (1/(order+1)))
        b_rescale = b_std

        apply_padding = padding
        if padding == Padding.CIRCULAR:
            apply_padding = Padding.VALID
            non_spatial_axes = (dimension_numbers[0].index('N'),
                                dimension_numbers[0].index('C'))
            spatial_axes = tuple(i for i in range(inputs.ndim)
                                 if i not in non_spatial_axes)
            inputs = _same_pad_for_filter_shape(inputs, filter_shape, strides,
                                                spatial_axes, 'wrap')

        return norm * lax.conv_general_dilated(
            inputs,
            W,
            strides,
            apply_padding.name,
            dimension_numbers=dimension_numbers) + b_rescale * b

    return taylor_init_fn, apply_fn
Example 13
def conv_params(input_format, out_channels, kernel_size, stride, padding):
    return stax.GeneralConv(input_format, out_channels, kernel_size, stride, padding)[0]
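A short usage sketch, assuming NHWC inputs; only the `init_fun` is kept, which is convenient when the parameters are consumed elsewhere (for example, fed straight into `lax.conv_general_dilated`):

from jax import random

init_fun = conv_params(('NHWC', 'HWIO', 'NHWC'), 16, (3, 3), (1, 1), 'SAME')
out_shape, (W, b) = init_fun(random.PRNGKey(0), (8, 32, 32, 3))
# W has HWIO layout (3, 3, 3, 16); b broadcasts over the NHWC output as (1, 1, 1, 16)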