def IgnoreConv2D(out_dim, W_init=he_normal(), b_init=normal(), kernel=3,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 transpose=False):
    """Conv2D layer for (x, t) inputs that ignores the time value t.

    Returns a stax-style ``(init_fun, apply_fun)`` pair whose apply_fun maps
    ``(x, t) -> (conv(x), t)``; the conv is applied to x only and t is passed
    through unchanged.  Interface matches the other *Conv2D layers in this
    file.  ``W_init``/``b_init``/``bias`` are accepted for signature parity
    but stax's conv uses its own initializers.

    Args:
        out_dim: number of output channels.
        kernel, stride, padding: usual 2-D conv hyperparameters.
        dilation, groups: must be 1 (stax's GeneralConv does not support them).
        transpose: if True, use a transposed convolution.
    """
    # stax has no dilation/groups support, so only the defaults are allowed.
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers, out_chan=out_dim,
            filter_shape=(kernel, kernel), strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers, out_chan=out_dim,
            filter_shape=(kernel, kernel), strides=(stride, stride),
            padding=padding)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        out = apply_fun_wrapped(params, x, **kwargs)
        return (out, t)

    # BUG FIX: previously returned apply_fun_wrapped, which expects a bare
    # array; every caller in this file passes the (x, t) tuple, so the
    # tuple-aware wrapper defined above must be returned instead.
    return init_fun_wrapped, apply_fun
def ConcatSquashConv2D(out_dim, W_init=he_normal(), b_init=normal(), kernel=3,
                       stride=1, padding=0, dilation=1, groups=1, bias=True,
                       transpose=False):
    """Concat-squash Conv2D: conv(x) gated and shifted by hypernets of t.

    apply_fun computes ``conv(x) * sigmoid(t @ W_g + b_g) + (t @ W_b)`` with
    the gate/bias broadcast over the spatial dims (NHWC layout, per the
    trailing channel axis used throughout this file), and returns
    ``(out, t)``.

    Args:
        out_dim: number of output channels.
        W_init, b_init: initializers for the hypernetwork gate/bias weights.
        kernel, stride, padding: usual 2-D conv hyperparameters.
        dilation, groups: must be 1 (unsupported by stax's GeneralConv).
        transpose: if True, use a transposed convolution.
    """
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers, out_chan=out_dim,
            filter_shape=(kernel, kernel), strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers, out_chan=out_dim,
            filter_shape=(kernel, kernel), strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2, k3, k4 = random.split(rng, 4)
        output_shape_conv, params_conv = init_fun_wrapped(k1, input_shape)
        W_hyper_gate, b_hyper_gate = W_init(k2, (1, out_dim)), b_init(
            k3, (out_dim, ))
        W_hyper_bias = W_init(k4, (1, out_dim))
        return output_shape_conv, (params_conv, W_hyper_gate, b_hyper_gate,
                                   W_hyper_bias)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_conv, W_hyper_gate, b_hyper_gate, W_hyper_bias = params
        conv_out = apply_fun_wrapped(params_conv, x, **kwargs)
        # BUG FIX: the original used torch-style .view() on JAX arrays; on
        # jax.numpy arrays .view() reinterprets dtype, it does not reshape.
        t_row = np.reshape(t, (1, 1))
        gate_out = np.reshape(
            jax.nn.sigmoid(np.dot(t_row, W_hyper_gate) + b_hyper_gate),
            (1, 1, 1, -1))
        bias_out = np.reshape(np.dot(t_row, W_hyper_bias), (1, 1, 1, -1))
        out = conv_out * gate_out + bias_out
        return (out, t)

    return init_fun, apply_fun
def ConcatCoordConv2D(out_dim, W_init=he_normal(), b_init=normal(), kernel=3,
                      stride=1, padding=0, dilation=1, groups=1, bias=True,
                      transpose=False):
    """CoordConv-style Conv2D over (x, t): appends row-index, column-index
    and time channels to x before convolving.

    apply_fun maps ``(x, t) -> (conv([x, hh, ww, tt]), t)`` where hh/ww are
    the spatial coordinates and tt is t broadcast over the image (NHWC
    layout — channels last).  init_fun widens the input channel count by 3
    to account for the appended channels.

    Args:
        out_dim: number of output channels.
        kernel, stride, padding: usual 2-D conv hyperparameters.
        dilation, groups: must be 1 (unsupported by stax's GeneralConv).
        transpose: if True, use a transposed convolution.
    """
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers, out_chan=out_dim,
            filter_shape=(kernel, kernel), strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers, out_chan=out_dim,
            filter_shape=(kernel, kernel), strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        concat_input_shape = list(input_shape)
        # add time and coord channels; channel axis is last (NHWC), unlike
        # torch's axis 1
        concat_input_shape[-1] += 3
        concat_input_shape = tuple(concat_input_shape)
        return init_fun_wrapped(rng, concat_input_shape)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        b, h, w, c = x.shape
        # BUG FIX: the original used torch-style .view()/.expand() on JAX
        # arrays; use reshape + broadcast_to instead.  Coordinates are cast
        # to x.dtype so the concatenate below is well-typed.
        hh = np.broadcast_to(
            np.arange(h, dtype=x.dtype).reshape(1, h, 1, 1), (b, h, w, 1))
        ww = np.broadcast_to(
            np.arange(w, dtype=x.dtype).reshape(1, 1, w, 1), (b, h, w, 1))
        tt = np.broadcast_to(np.reshape(t, (1, 1, 1, 1)), (b, h, w, 1))
        x_aug = np.concatenate([x, hh, ww, tt], axis=-1)
        out = apply_fun_wrapped(params, x_aug, **kwargs)
        return (out, t)

    return init_fun, apply_fun
def ConcatConv2D_v2(out_dim, W_init=he_normal(), b_init=normal(), kernel=3,
                    stride=1, padding=0, dilation=1, groups=1, bias=True,
                    transpose=False):
    """Conv2D with an additive time-dependent bias: conv(x) + (t @ W_b).

    apply_fun maps ``(x, t) -> (conv(x) + hyper_bias(t), t)`` with the bias
    broadcast over batch and spatial dims (NHWC layout — channels last).

    Args:
        out_dim: number of output channels.
        W_init: initializer for the hypernetwork bias weight.
        kernel, stride, padding: usual 2-D conv hyperparameters.
        dilation, groups: must be 1 (unsupported by stax's GeneralConv).
        transpose: if True, use a transposed convolution.
    """
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers, out_chan=out_dim,
            filter_shape=(kernel, kernel), strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers, out_chan=out_dim,
            filter_shape=(kernel, kernel), strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        output_shape_conv, params_conv = init_fun_wrapped(k1, input_shape)
        W_hyper_bias = W_init(k2, (1, out_dim))
        return output_shape_conv, (params_conv, W_hyper_bias)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_conv, W_hyper_bias = params
        # BUG FIX: the original used torch-style .view() on JAX arrays; on
        # jax.numpy arrays .view() reinterprets dtype, it does not reshape.
        # if NCHW instead of NHWC the bias would be shaped (1, -1, 1, 1).
        hyper_bias = np.reshape(
            np.dot(np.reshape(t, (1, 1)), W_hyper_bias), (1, 1, 1, -1))
        out = apply_fun_wrapped(params_conv, x, **kwargs) + hyper_bias
        return (out, t)

    return init_fun, apply_fun
def BlendConv2D(out_dim, W_init=he_normal(), b_init=normal(), kernel=3,
                stride=1, padding=0, dilation=1, groups=1, bias=True,
                transpose=False):
    """Conv2D that linearly blends two convolutions by the time value t.

    apply_fun computes ``f(x) + (g(x) - f(x)) * t`` — i.e. f at t=0, g at
    t=1, linear interpolation in between — and returns ``(out, t)``.
    init_fun initializes two independent parameter sets for f and g.

    Args:
        out_dim: number of output channels.
        kernel, stride, padding: usual 2-D conv hyperparameters.
        dilation, groups: must be 1 (unsupported by stax's GeneralConv).
        transpose: if True, use a transposed convolution.
    """
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers, out_chan=out_dim,
            filter_shape=(kernel, kernel), strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers, out_chan=out_dim,
            filter_shape=(kernel, kernel), strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        output_shape, params_f = init_fun_wrapped(k1, input_shape)
        _, params_g = init_fun_wrapped(k2, input_shape)
        return output_shape, (params_f, params_g)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_f, params_g = params
        # CONSISTENCY FIX: forward **kwargs to the wrapped conv applies, as
        # every other *Conv2D layer in this file does.
        f = apply_fun_wrapped(params_f, x, **kwargs)
        g = apply_fun_wrapped(params_g, x, **kwargs)
        out = f + (g - f) * t
        return (out, t)

    return init_fun, apply_fun
def ConcatConv2D(out_dim, W_init=he_normal(), b_init=normal(), kernel=3,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 transpose=False):
    """Conv2D over (x, t) that appends t as an extra input channel.

    apply_fun maps ``(x, t) -> (conv([x, t·1]), t)`` where the time value is
    broadcast into a constant channel appended on the last (channel) axis.
    init_fun widens the input channel count by one accordingly.

    Args:
        out_dim: number of output channels.
        kernel, stride, padding: usual 2-D conv hyperparameters.
        dilation, groups: must be 1 (unsupported by stax's GeneralConv).
        transpose: if True, use a transposed convolution.
    """
    assert dilation == 1 and groups == 1
    conv_factory = stax.GeneralConvTranspose if transpose else stax.GeneralConv
    init_fun_wrapped, apply_fun_wrapped = conv_factory(
        dimension_numbers, out_chan=out_dim, filter_shape=(kernel, kernel),
        strides=(stride, stride), padding=padding)

    def init_fun(rng, input_shape):
        # note, input shapes only take x; widen the trailing (channel) dim
        # by one for the appended time channel.
        widened = list(input_shape)
        widened[-1] += 1
        return init_fun_wrapped(rng, tuple(widened))

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        time_channel = t * np.ones_like(x[:, :, :, :1])
        augmented = np.concatenate([x, time_channel], axis=-1)
        return (apply_fun_wrapped(params, augmented, **kwargs), t)

    return init_fun, apply_fun