Example #1
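All six layers below share one convention: apply_fun consumes an (x, t) tuple, convolves x (optionally conditioned on the scalar time t), and returns (out, t) so layers can be chained. The snippets reference names defined at module scope; a minimal setup sketch is below (the NHWC dimension_numbers is an assumption, consistent with the channels-last indexing x[:, :, :, :1] used later):

import jax
import jax.numpy as np  # the examples use `np` for jax.numpy
from jax import random
from jax.example_libraries import stax  # `jax.experimental.stax` on older jax
from jax.nn.initializers import he_normal, normal

# Assumed layout: batch, height, width, channel (channels last).
dimension_numbers = ('NHWC', 'HWIO', 'NHWC')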
def IgnoreConv2D(out_dim,
                 W_init=he_normal(),
                 b_init=normal(),
                 kernel=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 transpose=False):
    assert dilation == 1 and groups == 1
    if isinstance(padding, int):
        # lax expects 'SAME'/'VALID' or ((lo, hi), (lo, hi)) pairs, not a torch-style int
        padding = ((padding, padding), (padding, padding))
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        out = apply_fun_wrapped(params, x, **kwargs)
        return (out, t)

    return init_fun_wrapped, apply_fun  # return the tuple-aware wrapper, not the raw conv apply
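IgnoreConv2D is the unconditioned baseline: it convolves x and threads t through untouched. A hypothetical smoke test of the calling convention (shapes are illustrative):

init_fun, apply_fun = IgnoreConv2D(out_dim=8, padding='SAME')
out_shape, params = init_fun(random.PRNGKey(0), (1, 32, 32, 3))
out, t = apply_fun(params, (np.zeros((1, 32, 32, 3)), np.array(0.5)))
assert out.shape == (1, 32, 32, 8)  # 'SAME' padding keeps H and W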
Example #2
def ConcatSquashConv2D(out_dim,
                       W_init=he_normal(),
                       b_init=normal(),
                       kernel=3,
                       stride=1,
                       padding=0,
                       dilation=1,
                       groups=1,
                       bias=True,
                       transpose=False):
    assert dilation == 1 and groups == 1
    if isinstance(padding, int):
        padding = ((padding, padding), (padding, padding))
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2, k3, k4 = random.split(rng, 4)
        output_shape_conv, params_conv = init_fun_wrapped(k1, input_shape)
        W_hyper_gate, b_hyper_gate = W_init(k2, (1, out_dim)), b_init(
            k3, (out_dim, ))
        W_hyper_bias = W_init(k4, (1, out_dim))
        return output_shape_conv, (params_conv, W_hyper_gate, b_hyper_gate,
                                   W_hyper_bias)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_conv, W_hyper_gate, b_hyper_gate, W_hyper_bias = params
        conv_out = apply_fun_wrapped(params_conv, x, **kwargs)
        gate_out = jax.nn.sigmoid(
            np.dot(t.reshape(1, 1), W_hyper_gate) + b_hyper_gate).reshape(
                1, 1, 1, -1)  # .reshape, not torch's .view
        bias_out = np.dot(t.reshape(1, 1), W_hyper_bias).reshape(1, 1, 1, -1)
        out = conv_out * gate_out + bias_out
        return (out, t)

    return init_fun, apply_fun
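ConcatSquashConv2D modulates each output channel with the scalar time: out = conv(x) * sigmoid(t * W_gate + b_gate) + t * W_bias, with the gate and bias reshaped to (1, 1, 1, out_dim) so they broadcast over the NHWC feature map. A hypothetical shape check:

init_fun, apply_fun = ConcatSquashConv2D(out_dim=8, padding='SAME')
out_shape, params = init_fun(random.PRNGKey(0), (1, 32, 32, 3))
params_conv, W_gate, b_gate, W_bias = params
assert W_gate.shape == (1, 8) and b_gate.shape == (8,) and W_bias.shape == (1, 8)
out, _ = apply_fun(params, (np.zeros((1, 32, 32, 3)), np.array(0.1)))
assert out.shape == (1, 32, 32, 8)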
Example #3
def ConcatCoordConv2D(out_dim,
                      W_init=he_normal(),
                      b_init=normal(),
                      kernel=3,
                      stride=1,
                      padding=0,
                      dilation=1,
                      groups=1,
                      bias=True,
                      transpose=False):
    assert dilation == 1 and groups == 1
    if isinstance(padding, int):
        padding = ((padding, padding), (padding, padding))
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        concat_input_shape = list(input_shape)
        # add time and coord channels; channel axis is last (NHWC) vs. 1 in torch (NCHW)
        concat_input_shape[-1] += 3
        concat_input_shape = tuple(concat_input_shape)
        return init_fun_wrapped(rng, concat_input_shape)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        b, h, w, c = x.shape
        # torch's .view/.expand become .reshape/np.broadcast_to in jax
        hh = np.broadcast_to(
            np.arange(h, dtype=x.dtype).reshape(1, h, 1, 1), (b, h, w, 1))
        ww = np.broadcast_to(
            np.arange(w, dtype=x.dtype).reshape(1, 1, w, 1), (b, h, w, 1))
        tt = np.broadcast_to(t.reshape(1, 1, 1, 1), (b, h, w, 1))
        x_aug = np.concatenate([x, hh, ww, tt], axis=-1)
        out = apply_fun_wrapped(params, x_aug, **kwargs)
        return (out, t)

    return init_fun, apply_fun
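ConcatCoordConv2D augments the input with three extra channels (row index, column index, and time) before the convolution, which is why init_fun grows the channel dimension by 3. A hypothetical look at the augmentation alone:

x = np.zeros((2, 4, 5, 3))
b, h, w, c = x.shape
hh = np.broadcast_to(np.arange(h, dtype=x.dtype).reshape(1, h, 1, 1), (b, h, w, 1))
ww = np.broadcast_to(np.arange(w, dtype=x.dtype).reshape(1, 1, w, 1), (b, h, w, 1))
tt = np.full((b, h, w, 1), 0.5)
assert np.concatenate([x, hh, ww, tt], axis=-1).shape == (2, 4, 5, 6)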
Example #4
def ConcatConv2D_v2(out_dim,
                    W_init=he_normal(),
                    b_init=normal(),
                    kernel=3,
                    stride=1,
                    padding=0,
                    dilation=1,
                    groups=1,
                    bias=True,
                    transpose=False):
    assert dilation == 1 and groups == 1
    if isinstance(padding, int):
        padding = ((padding, padding), (padding, padding))
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        output_shape_conv, params_conv = init_fun_wrapped(k1, input_shape)
        W_hyper_bias = W_init(k2, (1, out_dim))

        return output_shape_conv, (params_conv, W_hyper_bias)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_conv, W_hyper_bias = params
        out = apply_fun_wrapped(params_conv, x, **kwargs) + np.dot(
            t.reshape(1, 1), W_hyper_bias).reshape(
                1, 1, 1, -1)  # if NCHW instead of NHWC: .reshape(1, -1, 1, 1)
        return (out, t)

    return init_fun, apply_fun
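This variant is ConcatSquashConv2D without the gate: out = conv(x) + t * W_bias, a per-channel time-dependent bias. A hypothetical check of the parameter layout:

init_fun, apply_fun = ConcatConv2D_v2(out_dim=8, padding='SAME')
out_shape, (params_conv, W_bias) = init_fun(random.PRNGKey(0), (1, 32, 32, 3))
assert W_bias.shape == (1, 8)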
Example #5
def BlendConv2D(out_dim,
                W_init=he_normal(),
                b_init=normal(),
                kernel=3,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=True,
                transpose=False):
    assert dilation == 1 and groups == 1
    if isinstance(padding, int):
        padding = ((padding, padding), (padding, padding))
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        output_shape, params_f = init_fun_wrapped(k1, input_shape)
        _, params_g = init_fun_wrapped(k2, input_shape)
        return output_shape, (params_f, params_g)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_f, params_g = params
        f = apply_fun_wrapped(params_f, x, **kwargs)
        g = apply_fun_wrapped(params_g, x, **kwargs)
        out = f + (g - f) * t
        return (out, t)

    return init_fun, apply_fun
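BlendConv2D learns two convolutions and linearly interpolates between them in time: out = f + (g - f) * t = (1 - t) * f + t * g, so at t = 0 the layer is exactly the f branch and at t = 1 exactly the g branch, at the cost of doubling the parameters. A hypothetical endpoint sketch:

init_fun, apply_fun = BlendConv2D(out_dim=8, padding='SAME')
out_shape, params = init_fun(random.PRNGKey(0), (1, 32, 32, 3))
x = np.ones((1, 32, 32, 3))
f0, _ = apply_fun(params, (x, np.array(0.0)))  # equals the f branch
g1, _ = apply_fun(params, (x, np.array(1.0)))  # equals the g branch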
Example #6
def ConcatConv2D(out_dim,
                 W_init=he_normal(),
                 b_init=normal(),
                 kernel=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 transpose=False):
    assert dilation == 1 and groups == 1
    if isinstance(padding, int):
        padding = ((padding, padding), (padding, padding))
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):  # note: input_shape describes x only, not t
        concat_input_shape = list(input_shape)
        concat_input_shape[-1] += 1  # add time channel dim
        concat_input_shape = tuple(concat_input_shape)
        return init_fun_wrapped(rng, concat_input_shape)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        tt = np.ones_like(x[:, :, :, :1]) * t
        xtt = np.concatenate([x, tt], axis=-1)
        out = apply_fun_wrapped(params, xtt, **kwargs)
        return (out, t)

    return init_fun, apply_fun
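ConcatConv2D appends a constant-t channel to x (the plain "concat" conditioning), so init_fun widens the channel dimension by one. Because every layer here maps (x, t) to (out, t) while init_fun sees only x's shape, the layers compose directly with stax.serial. A hypothetical sketch (a real network would interleave tuple-aware nonlinearities):

net_init, net_apply = stax.serial(
    ConcatConv2D(32, padding='SAME'),
    ConcatConv2D(3, padding='SAME'),
)
out_shape, params = net_init(random.PRNGKey(0), (1, 32, 32, 3))
out, t = net_apply(params, (np.zeros((1, 32, 32, 3)), np.array(0.0)))
assert out.shape == (1, 32, 32, 3)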