def IgnoreConv2D(out_dim, W_init=he_normal(), b_init=normal(), kernel=3,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 transpose=False):
    """Conv layer that ignores the time input t and passes it through unchanged."""
    # `dimension_numbers` is assumed to be a module-level constant such as
    # ('NHWC', 'HWIO', 'NHWC'), matching the channels-last layout used below.
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers, out_chan=out_dim, filter_shape=(kernel, kernel),
            strides=(stride, stride), padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers, out_chan=out_dim, filter_shape=(kernel, kernel),
            strides=(stride, stride), padding=padding)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        out = apply_fun_wrapped(params, x, **kwargs)
        return (out, t)

    return init_fun_wrapped, apply_fun
def GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None,
                padding='VALID', W_gain=1.0, W_init=stax.randn(1.0),
                b_gain=0.0, b_init=stax.randn(1.0)):
    """Layer construction function for a general convolution layer.

    Uses jax.experimental.stax.GeneralConv as a base, but rescales the
    pre-activation by W_gain / sqrt(fan_in) and the bias by b_gain at apply time.
    """
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1,) * len(filter_shape)
    strides = strides or one
    init_fun, _ = stax.GeneralConv(dimension_numbers, out_chan, filter_shape,
                                   strides, padding, W_init, b_init)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        # fan-in: input channels times the filter size
        norm = inputs.shape[lhs_spec.index('C')]
        norm *= functools.reduce(op.mul, filter_shape)
        norm = W_gain / np.sqrt(norm)
        return norm * lax.conv_general_dilated(inputs, W, strides, padding, one,
                                               one, dimension_numbers) + b_gain * b

    return init_fun, apply_fun
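# A minimal usage sketch for the layer above (shapes and argument values are
# illustrative assumptions, not from the original source): build one rescaled
# conv layer over NHWC inputs and apply it to random data.
from jax import random

conv_init, conv_apply = GeneralConv(('NHWC', 'HWIO', 'NHWC'), out_chan=32,
                                    filter_shape=(3, 3), padding='SAME',
                                    W_gain=2.0)
rng = random.PRNGKey(0)
out_shape, params = conv_init(rng, (8, 28, 28, 1))
x = random.normal(rng, (8, 28, 28, 1))
y = conv_apply(params, x)  # y.shape == (8, 28, 28, 32)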
def __init__(self, num_classes=100, encoding=True):
    blocks = [
        stax.GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        stax.BatchNorm(), stax.Relu,
        stax.MaxPool((3, 3), strides=(2, 2)),
        self.ConvBlock(3, [64, 64, 256], strides=(1, 1)),
        self.IdentityBlock(3, [64, 64]),
        self.IdentityBlock(3, [64, 64]),
        self.ConvBlock(3, [128, 128, 512]),
        self.IdentityBlock(3, [128, 128]),
        self.IdentityBlock(3, [128, 128]),
        self.IdentityBlock(3, [128, 128]),
        self.ConvBlock(3, [256, 256, 1024]),
        self.IdentityBlock(3, [256, 256]),
        self.IdentityBlock(3, [256, 256]),
        self.IdentityBlock(3, [256, 256]),
        self.IdentityBlock(3, [256, 256]),
        self.IdentityBlock(3, [256, 256]),
        self.ConvBlock(3, [512, 512, 2048]),
        self.IdentityBlock(3, [512, 512]),
        self.IdentityBlock(3, [512, 512]),
        stax.AvgPool((7, 7)),
    ]
    if not encoding:
        blocks.append(stax.Flatten)
        blocks.append(stax.Dense(num_classes))
    self.model = stax.serial(*blocks)
def ResidualBlock(out_channels, kernel_size, stride, padding, input_format):
    double_conv = stax.serial(
        stax.GeneralConv(input_format, out_channels, kernel_size, stride, padding),
        stax.Elu,
    )
    return Module(
        *stax.serial(
            stax.FanOut(2),
            stax.parallel(double_conv, stax.Identity),
            stax.FanInSum,
        )
    )
def ResNet(hidden_channels, out_channels, depth):
    # time integration module
    backbone = stax.serial(
        stax.GeneralConv(
            ("NCDWH", "IDWHO", "NCDWH"), hidden_channels, (4, 3, 3), (1, 1, 1), "SAME"
        ),
        *[
            ResidualBlock(
                hidden_channels,
                (4, 3, 3),
                (1, 1, 1),
                "SAME",
                ("NCDWH", "IDWHO", "NCDWH"),
            )
            for _ in range(depth)
        ],
        stax.GeneralConv(
            ("NCDWH", "IDWHO", "NCDWH"), out_channels, (4, 3, 3), (1, 1, 1), "SAME"
        ),
        stax.GeneralConv(("NCDWH", "IDWHO", "NCDWH"), 3, (3, 3, 3), (1, 1, 1), "SAME"),
    )
    # euler scheme
    return stax.serial(stax.FanOut(2), stax.parallel(stax.Identity, backbone), Euler())
def ConcatSquashConv2D(out_dim, W_init=he_normal(), b_init=normal(), kernel=3,
                       stride=1, padding=0, dilation=1, groups=1, bias=True,
                       transpose=False):
    """Conv layer whose output is gated and shifted by hyper-networks of the time t."""
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers, out_chan=out_dim, filter_shape=(kernel, kernel),
            strides=(stride, stride), padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers, out_chan=out_dim, filter_shape=(kernel, kernel),
            strides=(stride, stride), padding=padding)

    def init_fun(rng, input_shape):
        k1, k2, k3, k4 = random.split(rng, 4)
        output_shape_conv, params_conv = init_fun_wrapped(k1, input_shape)
        W_hyper_gate, b_hyper_gate = W_init(k2, (1, out_dim)), b_init(k3, (out_dim,))
        W_hyper_bias = W_init(k4, (1, out_dim))
        return output_shape_conv, (params_conv, W_hyper_gate, b_hyper_gate,
                                   W_hyper_bias)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_conv, W_hyper_gate, b_hyper_gate, W_hyper_bias = params
        conv_out = apply_fun_wrapped(params_conv, x, **kwargs)
        # broadcast the scalar t through the hyper-networks over the channel axis (NHWC)
        gate_out = jax.nn.sigmoid(
            np.dot(t.reshape(1, 1), W_hyper_gate) + b_hyper_gate).reshape(1, 1, 1, -1)
        bias_out = np.dot(t.reshape(1, 1), W_hyper_bias).reshape(1, 1, 1, -1)
        out = conv_out * gate_out + bias_out
        return (out, t)

    return init_fun, apply_fun
def ConcatCoordConv2D(out_dim, W_init=he_normal(), b_init=normal(), kernel=3,
                      stride=1, padding=0, dilation=1, groups=1, bias=True,
                      transpose=False):
    """Conv layer that concatenates row, column and time channels to its input."""
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers, out_chan=out_dim, filter_shape=(kernel, kernel),
            strides=(stride, stride), padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers, out_chan=out_dim, filter_shape=(kernel, kernel),
            strides=(stride, stride), padding=padding)

    def init_fun(rng, input_shape):
        concat_input_shape = list(input_shape)
        # add time and coord channels on the last (channel) axis; this code uses
        # NHWC, whereas the torch original used channel axis 1 (NCHW)
        concat_input_shape[-1] += 3
        concat_input_shape = tuple(concat_input_shape)
        return init_fun_wrapped(rng, concat_input_shape)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        b, h, w, c = x.shape
        hh = np.broadcast_to(np.arange(h).reshape(1, h, 1, 1), (b, h, w, 1))
        ww = np.broadcast_to(np.arange(w).reshape(1, 1, w, 1), (b, h, w, 1))
        tt = np.broadcast_to(t.reshape(1, 1, 1, 1), (b, h, w, 1))
        x_aug = np.concatenate([x, hh, ww, tt], axis=-1)
        out = apply_fun_wrapped(params, x_aug, **kwargs)
        return (out, t)

    return init_fun, apply_fun
def ConcatConv2D_v2(out_dim, W_init=he_normal(), b_init=normal(), kernel=3,
                    stride=1, padding=0, dilation=1, groups=1, bias=True,
                    transpose=False):
    """Conv layer plus a time-dependent bias produced by a hyper-bias on t."""
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers, out_chan=out_dim, filter_shape=(kernel, kernel),
            strides=(stride, stride), padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers, out_chan=out_dim, filter_shape=(kernel, kernel),
            strides=(stride, stride), padding=padding)

    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        output_shape_conv, params_conv = init_fun_wrapped(k1, input_shape)
        W_hyper_bias = W_init(k2, (1, out_dim))
        return output_shape_conv, (params_conv, W_hyper_bias)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_conv, W_hyper_bias = params
        # reshape to (1, 1, 1, -1) for NHWC; use (1, -1, 1, 1) instead for NCHW
        out = apply_fun_wrapped(params_conv, x, **kwargs) + np.dot(
            t.reshape(1, 1), W_hyper_bias).reshape(1, 1, 1, -1)
        return (out, t)

    return init_fun, apply_fun
def BlendConv2D(out_dim, W_init=he_normal(), b_init=normal(), kernel=3,
                stride=1, padding=0, dilation=1, groups=1, bias=True,
                transpose=False):
    """Linear blend of two conv layers in time: f(x) + (g(x) - f(x)) * t."""
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers, out_chan=out_dim, filter_shape=(kernel, kernel),
            strides=(stride, stride), padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers, out_chan=out_dim, filter_shape=(kernel, kernel),
            strides=(stride, stride), padding=padding)

    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        output_shape, params_f = init_fun_wrapped(k1, input_shape)
        _, params_g = init_fun_wrapped(k2, input_shape)
        return output_shape, (params_f, params_g)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_f, params_g = params
        f = apply_fun_wrapped(params_f, x)
        g = apply_fun_wrapped(params_g, x)
        out = f + (g - f) * t
        return (out, t)

    return init_fun, apply_fun
def ConcatConv2D(out_dim, W_init=he_normal(), b_init=normal(), kernel=3,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 transpose=False):
    """Conv layer applied to the input with the time t concatenated as an extra channel."""
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers, out_chan=out_dim, filter_shape=(kernel, kernel),
            strides=(stride, stride), padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers, out_chan=out_dim, filter_shape=(kernel, kernel),
            strides=(stride, stride), padding=padding)

    def init_fun(rng, input_shape):
        # note: input_shape describes x only; the time channel is added here
        concat_input_shape = list(input_shape)
        concat_input_shape[-1] += 1  # add time channel dim
        concat_input_shape = tuple(concat_input_shape)
        return init_fun_wrapped(rng, concat_input_shape)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        tt = np.ones_like(x[:, :, :, :1]) * t
        xtt = np.concatenate([x, tt], axis=-1)
        out = apply_fun_wrapped(params, xtt, **kwargs)
        return (out, t)

    return init_fun, apply_fun
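# A minimal usage sketch for the (x, t) layers above (IgnoreConv2D, ConcatSquashConv2D,
# ConcatCoordConv2D, ConcatConv2D_v2, BlendConv2D, ConcatConv2D). It assumes the
# module-level dimension_numbers = ('NHWC', 'HWIO', 'NHWC') and the he_normal /
# normal initializers imported elsewhere in this file; shapes are illustrative.
import jax.numpy as np
from jax import random

layer_init, layer_apply = ConcatConv2D(out_dim=16, kernel=3, stride=1, padding='SAME')
rng = random.PRNGKey(0)
out_shape, params = layer_init(rng, (4, 32, 32, 8))  # shape of x only
x = random.normal(rng, (4, 32, 32, 8))
out, t = layer_apply(params, (x, np.array(0.5)))     # out: (4, 32, 32, 16); t passes through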
def _GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None,
                 padding=Padding.VALID.value, W_std=1.0, W_init=_randn(1.0),
                 b_std=0.0, b_init=_randn(1.0)):
    """Layer construction function for a general convolution layer.

    Based on `jax.experimental.stax.GeneralConv`. Has a similar API apart from:

    Args:
      padding: in addition to `VALID` and `SAME` padding, supports `CIRCULAR`,
        not available in `jax.experimental.stax.GeneralConv`.
    """
    if dimension_numbers != _CONV_DIMENSION_NUMBERS:
        raise NotImplementedError('Dimension numbers %s not implemented.' %
                                  str(dimension_numbers))

    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1,) * len(filter_shape)
    strides = strides or one
    padding = Padding(padding)

    init_padding = padding
    if padding == Padding.CIRCULAR:
        init_padding = Padding.SAME

    init_fun, _ = stax.GeneralConv(dimension_numbers, out_chan, filter_shape,
                                   strides, init_padding.value, W_init, b_init)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        norm = inputs.shape[lhs_spec.index('C')]
        norm *= np.prod(filter_shape)
        apply_padding = padding
        if padding == Padding.CIRCULAR:
            apply_padding = Padding.VALID
            inputs = _same_pad_for_filter_shape(inputs, filter_shape, strides,
                                                (1, 2), 'wrap')
        norm = W_std / np.sqrt(norm)
        return norm * lax.conv_general_dilated(
            inputs, W, strides, apply_padding.value,
            dimension_numbers=dimension_numbers) + b_std * b

    def ker_fun(kernels):
        """Compute the transformed kernels after a conv layer."""
        # var1: batch_1 * height * width
        # var2: batch_2 * height * width
        # nngp, ntk: batch_1 * batch_2 * height * height * width * width (pooling)
        #   or batch_1 * batch_2 * height * width (flattening)
        var1, nngp, var2, ntk, _, is_height_width = kernels

        if nngp.ndim == 4:
            def conv_var(x):
                x = _conv_var_3d(x, filter_shape, strides, padding)
                x = _affine(x, W_std, b_std)
                return x

            def conv_nngp(x):
                if _is_array(x):
                    x = _conv_nngp_4d(x, filter_shape, strides, padding)
                x = _affine(x, W_std, b_std)
                return x

        elif nngp.ndim == 6:
            if not is_height_width:
                filter_shape_nngp = filter_shape[::-1]
                strides_nngp = strides[::-1]
            else:
                filter_shape_nngp = filter_shape
                strides_nngp = strides

            def conv_var(x):
                x = _conv_var_3d(x, filter_shape_nngp, strides_nngp, padding)
                if x is not None:
                    x = np.transpose(x, (0, 2, 1))
                x = _affine(x, W_std, b_std)
                return x

            def conv_nngp(x):
                if _is_array(x):
                    x = _conv_nngp_6d_double_conv(x, filter_shape_nngp,
                                                  strides_nngp, padding)
                x = _affine(x, W_std, b_std)
                return x

            is_height_width = not is_height_width

        else:
            raise ValueError('`nngp` array must be either 4d or 6d, got %d.' %
                             nngp.ndim)

        var1 = conv_var(var1)
        var2 = conv_var(var2)
        nngp = conv_nngp(nngp)
        ntk = conv_nngp(ntk) + nngp - b_std**2 if ntk is not None else ntk

        return Kernel(var1, nngp, var2, ntk, True, is_height_width)

    return init_fun, apply_fun, ker_fun
def TaylorConv(out_chan, filter_shape, strides=None, padding=Padding.VALID.name,
               W_std=1.0, W_init=_randn(1.0), b_std=0.0, b_init=_randn(1.0),
               order=2):
    """Layer construction function for a convolution layer with Taylorized
    parameterization.

    Based on `jax.experimental.stax.GeneralConv`. Has a similar API apart from:

    Args:
      padding: in addition to `VALID` and `SAME` padding, supports `CIRCULAR`,
        not available in `jax.experimental.stax.GeneralConv`.
    """
    assert isinstance(order, int) and order >= 1

    dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
    lhs_spec, rhs_spec, out_spec = dimension_numbers

    one = (1,) * len(filter_shape)
    strides = strides or one
    padding = Padding(padding)

    init_padding = padding
    if padding == Padding.CIRCULAR:
        init_padding = Padding.SAME

    def input_total_dim(input_shape):
        return input_shape[lhs_spec.index('C')] * np.prod(filter_shape)

    ntk_init_fn, _ = jax_stax.GeneralConv(dimension_numbers, out_chan, filter_shape,
                                          strides, init_padding.name, W_init, b_init)

    def taylor_init_fn(rng, input_shape):
        output_shape, (W, b) = ntk_init_fn(rng, input_shape)
        norm = W_std / (input_total_dim(input_shape) ** ((order - 1) / (2 * order + 2)))
        return output_shape, (W * norm, b * b_std)

    def apply_fn(params, inputs, **kwargs):
        W, b = params
        norm = W_std / (input_total_dim(inputs.shape) ** (1 / (order + 1)))
        b_rescale = b_std
        apply_padding = padding
        if padding == Padding.CIRCULAR:
            apply_padding = Padding.VALID
            non_spatial_axes = (dimension_numbers[0].index('N'),
                                dimension_numbers[0].index('C'))
            spatial_axes = tuple(i for i in range(inputs.ndim)
                                 if i not in non_spatial_axes)
            inputs = _same_pad_for_filter_shape(inputs, filter_shape, strides,
                                                spatial_axes, 'wrap')
        return norm * lax.conv_general_dilated(
            inputs, W, strides, apply_padding.name,
            dimension_numbers=dimension_numbers) + b_rescale * b

    return taylor_init_fn, apply_fn
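# Usage sketch for TaylorConv (illustrative shapes and hyperparameters; assumes the
# Padding enum, _randn, and the jax_stax import defined elsewhere in this module).
from jax import random

init_fn, apply_fn = TaylorConv(out_chan=32, filter_shape=(3, 3), padding='SAME',
                               W_std=1.0, b_std=0.1, order=2)
rng = random.PRNGKey(0)
out_shape, params = init_fn(rng, (8, 28, 28, 3))
x = random.normal(rng, (8, 28, 28, 3))
y = apply_fn(params, x)  # y.shape == (8, 28, 28, 32)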
def conv_params(input_format, out_channels, kernel_size, stride, padding):
    # stax.GeneralConv returns (init_fun, apply_fun); keep only the initializer
    return stax.GeneralConv(input_format, out_channels, kernel_size, stride, padding)[0]
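# Usage sketch: obtain only the parameter initializer, e.g. to inspect or pre-build
# weight shapes without an apply function. Shapes are illustrative; `stax` is
# assumed imported as in the snippet above.
from jax import random

init_fun = conv_params(('NHWC', 'HWIO', 'NHWC'), 16, (3, 3), (1, 1), 'SAME')
out_shape, (W, b) = init_fun(random.PRNGKey(0), (4, 32, 32, 3))
# W is in HWIO layout: (3, 3, 3, 16); b broadcasts over the 16 output channels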