Example #1
    def __call__(self, z, train: bool = True):
        # Common arguments
        conv_kwargs = {
            'kernel_size': (4, 4),
            'strides': (2, 2),
            'padding': 'SAME',
            'use_bias': False,
            'kernel_init': he_normal()
        }
        norm_kwargs = {
            'use_running_average': not train,
            'momentum': 0.99,
            'epsilon': 0.001,
            'use_scale': True,
            'use_bias': True
        }

        z = np.reshape(z, (1, 1, self.zdim))

        # Layer 1
        z = nn.ConvTranspose(features=512,
                             kernel_size=(4, 4),
                             strides=(1, 1),
                             padding='VALID',
                             use_bias=False,
                             kernel_init=he_normal())(z)
        z = nn.BatchNorm(**norm_kwargs)(z)
        z = nn.leaky_relu(z, 0.2)

        # Layer 2
        z = nn.ConvTranspose(features=256, **conv_kwargs)(z)
        z = nn.BatchNorm(**norm_kwargs)(z)
        z = nn.leaky_relu(z, 0.2)

        # Layer 3
        z = nn.ConvTranspose(features=128, **conv_kwargs)(z)
        z = nn.BatchNorm(**norm_kwargs)(z)
        z = nn.leaky_relu(z, 0.2)

        # Layer 4
        z = nn.ConvTranspose(features=64, **conv_kwargs)(z)
        z = nn.BatchNorm(**norm_kwargs)(z)
        z = nn.leaky_relu(z, 0.2)

        # Layer 5
        z = nn.ConvTranspose(features=1,
                             kernel_size=(4, 4),
                             strides=(2, 2),
                             padding='SAME',
                             use_bias=False,
                             kernel_init=nn.initializers.xavier_normal())(z)
        # x = nn.sigmoid(z)
        x = nn.softplus(z)

        return jnp.rot90(np.squeeze(x), k=2)  # Rotate to match TF output
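
For reference, a minimal sketch of how a decoder like this is typically wrapped and driven in Flax; the Generator class, its zdim field, and the shortened body are assumptions for illustration only (the snippet above also appears to rely on `import jax.numpy as np`).

import jax
import jax.numpy as jnp
import flax.linen as nn

class Generator(nn.Module):  # hypothetical wrapper; the real body is the __call__ above
    zdim: int = 100

    @nn.compact
    def __call__(self, z, train: bool = True):
        z = jnp.reshape(z, (1, 1, self.zdim))
        z = nn.ConvTranspose(features=8, kernel_size=(4, 4), strides=(1, 1),
                             padding='VALID', use_bias=False)(z)
        z = nn.BatchNorm(use_running_average=not train)(z)
        return nn.leaky_relu(z, 0.2)

rng = jax.random.PRNGKey(0)
model = Generator(zdim=100)
z = jax.random.normal(rng, (100,))
variables = model.init(rng, z, train=False)   # creates 'params' and 'batch_stats'
x = model.apply(variables, z, train=False)    # eval mode: BatchNorm uses running stats
# In training mode the BatchNorm statistics must be declared mutable:
# x, updates = model.apply(variables, z, train=True, mutable=['batch_stats'])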
Example #2
    def __call__(self, x, train: bool = True):
        # Common arguments
        kwargs = {
            'kernel_size': (4, 4),
            'strides': (2, 2),
            'padding': 'SAME',
            'use_bias': False,
            'kernel_init': he_normal()
        }

        # x = np.reshape(x, (64, 64, 1))
        x = x[..., None]

        # Layer 1
        x = nn.Conv(features=64, **kwargs)(x)
        x = nn.leaky_relu(x, 0.2)

        # Layer 2
        x = nn.Conv(features=128, **kwargs)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.leaky_relu(x, 0.2)

        # Layer 3
        x = nn.Conv(features=256, **kwargs)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.leaky_relu(x, 0.2)

        # Layer 4
        x = nn.Conv(features=512, **kwargs)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.leaky_relu(x, 0.2)

        # Layer 5
        x = nn.Conv(features=4096,
                    kernel_size=(4, 4),
                    strides=(1, 1),
                    padding='VALID',
                    use_bias=False,
                    kernel_init=he_normal())(x)
        x = nn.leaky_relu(x, 0.2)

        # Flatten
        x = x.flatten()

        # Predict latent variables
        z_mean = nn.Dense(features=self.zdim)(x)
        z_logvar = nn.Dense(features=self.zdim)(x)

        return z_mean, z_logvar
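
Downstream code presumably samples a latent vector from the returned mean and log-variance with the reparameterization trick; a minimal sketch (the function name is assumed):

import jax
import jax.numpy as jnp

def sample_latent(rng, z_mean, z_logvar):
    # z = mu + sigma * eps, eps ~ N(0, I), with sigma = exp(0.5 * logvar)
    eps = jax.random.normal(rng, z_mean.shape)
    return z_mean + jnp.exp(0.5 * z_logvar) * eps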
Example #3
def ConcatSquashLinear(out_dim, W_init=he_normal(), b_init=normal()):
    """ y = Sigmoid(at + c)(Wx + b) + dt. Note: he_normal only takes multi dim.
    """
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2, k3, k4, k5 = random.split(rng, 5)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim, ))
        w_t, w_tb = b_init(k3, (out_dim, )), b_init(k4, (out_dim, ))
        b_t = b_init(k5, (out_dim, ))
        return output_shape, (W, b, w_t, w_tb, b_t)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        W, b, w_t, w_tb, b_t = params

        # (W.x + b) *
        out = np.dot(x, W) + b
        # sigmoid(a.t + c)  +
        out *= jax.nn.sigmoid(w_t * t + w_tb)
        # d.t
        out += b_t * t

        return (out, t)

    return init_fun, apply_fun
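
A minimal usage sketch, assuming the surrounding module has `import jax`, `import jax.numpy as np`, `from jax import random`, and `from jax.nn.initializers import he_normal, normal` in scope:

init_fun, apply_fun = ConcatSquashLinear(out_dim=32)
out_shape, params = init_fun(random.PRNGKey(0), input_shape=(-1, 16))

x = np.ones((8, 16))                   # batch of 8 feature vectors
t = np.asarray(0.5)                    # scalar time
y, t_out = apply_fun(params, (x, t))   # y: (8, 32); t is passed through unchanged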
Example #4
def IgnoreConv2D(out_dim,
                 W_init=he_normal(),
                 b_init=normal(),
                 kernel=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        out = apply_fun_wrapped(params, x, **kwargs)
        return (out, t)

    return init_fun_wrapped, apply_fun  # return the (x, t)-aware wrapper, not the raw conv apply
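
These conv wrappers close over a module-level `dimension_numbers` that the excerpt does not show; a reasonable assumption, matching the NHWC tensors used elsewhere in these examples, is sketched below. Note also that `padding` is forwarded to `stax.GeneralConv`, which expects `'SAME'`, `'VALID'`, or explicit (low, high) pairs rather than the integer default of 0, so callers presumably override it.

from jax.example_libraries import stax  # older JAX releases: jax.experimental.stax

# Assumed layout spec: NHWC inputs/outputs, HWIO kernels.
dimension_numbers = ('NHWC', 'HWIO', 'NHWC')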
Example #5
def create_q_net(
    obs_dim, action_dim, rngkey=jax.random.PRNGKey(0)
) -> TT.Tuple[RT.NNParams, RT.NNParamsFn]:
    q_init, q_fn = serial(
        Dense(64, he_normal(), zeros),
        Relu,
        Dense(64, he_normal(), zeros),
        Relu,
        Dense(action_dim, he_normal(), zeros),
    )
    output_shape, q_params = q_init(rngkey, (1, obs_dim + action_dim))

    @jit
    def q_fn2(q, S, A):
        return q_fn(q, jnp.hstack([S, A]))

    return q_params, q_fn2
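
A minimal usage sketch, assuming `serial`, `Dense`, and `Relu` come from `jax.example_libraries.stax` and `zeros` from `jax.nn.initializers`:

import jax.numpy as jnp

obs_dim, action_dim = 8, 2
q_params, q_fn2 = create_q_net(obs_dim, action_dim)

S = jnp.zeros((32, obs_dim))       # batch of observations
A = jnp.zeros((32, action_dim))    # batch of actions
q_values = q_fn2(q_params, S, A)   # shape (32, action_dim)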
Example #6
def create_pi_net(
    obs_dim: int, action_dim: int, rngkey=jax.random.PRNGKey(0)
) -> TT.Tuple[RT.NNParams, RT.NNParamsFn]:
    pi_init, pi_fn = serial(
        Dense(64, he_normal(), zeros),
        Relu,
        FanOut(2),
        parallel(
            serial(
                Dense(64, he_normal(), zeros),
                Relu,
                Dense(action_dim, he_normal(), zeros),
            ),
            serial(
                Dense(64, he_normal(), zeros),
                Relu,
                Dense(action_dim, he_normal(), zeros),
            ),
        ),
    )
    output_shape, pi_params = pi_init(rngkey, (1, obs_dim))
    pi_fn = jit(pi_fn)
    return pi_params, pi_fn
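
Usage mirrors the Q-network above; the `FanOut`/`parallel` combination yields two heads of shape (batch, action_dim), which presumably parameterize a Gaussian policy (e.g. mean and log-std), though the snippet itself only fixes their shapes:

import jax.numpy as jnp

obs_dim, action_dim = 8, 2
pi_params, pi_fn = create_pi_net(obs_dim, action_dim)

S = jnp.zeros((32, obs_dim))
head_a, head_b = pi_fn(pi_params, S)   # each of shape (32, action_dim)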
Example #7
def IgnoreLinear(out_dim, W_init=he_normal(), b_init=normal()):
    """ y = Wx + b
    """
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2 = random.split(rng)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        W, b = params
        return (np.dot(x, W) + b, t)

    return init_fun, apply_fun
Example #8
def ConcatSquashConv2D(out_dim,
                       W_init=he_normal(),
                       b_init=normal(),
                       kernel=3,
                       stride=1,
                       padding=0,
                       dilation=1,
                       groups=1,
                       bias=True,
                       transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2, k3, k4 = random.split(rng, 4)
        output_shape_conv, params_conv = init_fun_wrapped(k1, input_shape)
        W_hyper_gate, b_hyper_gate = W_init(k2, (1, out_dim)), b_init(
            k3, (out_dim, ))
        W_hyper_bias = W_init(k4, (1, out_dim))
        return output_shape_conv, (params_conv, W_hyper_gate, b_hyper_gate,
                                   W_hyper_bias)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_conv, W_hyper_gate, b_hyper_gate, W_hyper_bias = params
        conv_out = apply_fun_wrapped(params_conv, x, **kwargs)
        # `.view(...)` is a PyTorch idiom; JAX arrays are reshaped instead.
        t_vec = np.reshape(t, (1, 1))
        gate_out = jax.nn.sigmoid(
            np.dot(t_vec, W_hyper_gate) + b_hyper_gate).reshape(1, 1, 1, -1)
        bias_out = np.dot(t_vec, W_hyper_bias).reshape(1, 1, 1, -1)
        out = conv_out * gate_out + bias_out
        return (out, t)

    return init_fun, apply_fun
Example #9
def ConcatCoordConv2D(out_dim,
                      W_init=he_normal(),
                      b_init=normal(),
                      kernel=3,
                      stride=1,
                      padding=0,
                      dilation=1,
                      groups=1,
                      bias=True,
                      transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        concat_input_shape = list(input_shape)
        # add 3 channels for time and coords (channel axis is last here, vs. dim 1 in the torch original)
        concat_input_shape[-1] += 3
        concat_input_shape = tuple(concat_input_shape)
        return init_fun_wrapped(rng, concat_input_shape)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        b, h, w, c = x.shape
        # Torch-style `.view(...).expand(...)` replaced with JAX reshape/broadcast.
        hh = np.broadcast_to(np.arange(h).reshape(1, h, 1, 1), (b, h, w, 1))
        ww = np.broadcast_to(np.arange(w).reshape(1, 1, w, 1), (b, h, w, 1))
        tt = np.broadcast_to(np.reshape(t, (1, 1, 1, 1)), (b, h, w, 1))
        x_aug = np.concatenate([x, hh, ww, tt], axis=-1)
        out = apply_fun_wrapped(params, x_aug, **kwargs)
        return (out, t)

    return init_fun, apply_fun
Example #10
def Dense(out_dim,
          W_init=he_normal(),
          b_init=normal(),
          rho_init=partial(const, c=-5)):
    """Layer constructor function for a dense (fully-connected) Bayesian linear layer."""
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2, k3, k4 = random.split(rng, 4)
        W_mu, b_mu = W_init(k1, (input_shape[-1], out_dim)), b_init(
            k2, (out_dim, ))
        W_rho, b_rho = rho_init((input_shape[-1], out_dim)), rho_init(
            (out_dim, ))
        return output_shape, (W_mu, b_mu, W_rho, b_rho)

    def apply_fun(params, inputs, rng, **kwargs):
        # print(inputs[0][0])
        inputs, kl = inputs
        # kl = 0
        subkeys = random.split(rng, 2)

        W_mu, b_mu, W_rho, b_rho = params
        W_eps = random.normal(subkeys[0], W_mu.shape)
        b_eps = random.normal(subkeys[1], b_mu.shape)
        # q dist
        W_std = np.exp(W_rho)
        b_std = np.exp(b_rho)

        W = W_eps * W_std + W_mu
        b = b_eps * b_std + b_mu

        # Bayes by Backprop training
        W_kl = normal_kldiv(W_mu, 0., W_rho, 0.)
        b_kl = normal_kldiv(b_mu, 0., b_rho, 0.)
        W_kl, b_kl = np.sum(W_kl), np.sum(b_kl)

        kl_loss = W_kl + b_kl
        kl_loss = kl_loss + np.array(
            kl)  # TODO: why do we get compatibility issues?
        # print(W.shape)

        return (np.dot(inputs, W) + b, kl_loss)

    return init_fun, apply_fun
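
A minimal sketch of threading the `(inputs, kl)` pair and an rng through this Bayesian layer; `const` and `normal_kldiv` are helpers from the surrounding codebase and are assumed to be in scope:

import jax.numpy as np
from jax import random

init_fun, apply_fun = Dense(out_dim=32)
out_shape, params = init_fun(random.PRNGKey(0), input_shape=(-1, 16))

x = np.ones((8, 16))
out, kl_loss = apply_fun(params, (x, 0.0), rng=random.PRNGKey(1))
# out: (8, 32) activations with weights sampled via the reparameterization trick
# kl_loss: accumulated KL(q(W, b) || p(W, b)) term for Bayes-by-Backprop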
Example #11
def ConcatConv2D_v2(out_dim,
                    W_init=he_normal(),
                    b_init=normal(),
                    kernel=3,
                    stride=1,
                    padding=0,
                    dilation=1,
                    groups=1,
                    bias=True,
                    transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        output_shape_conv, params_conv = init_fun_wrapped(k1, input_shape)
        W_hyper_bias = W_init(k2, (1, out_dim))

        return output_shape_conv, (params_conv, W_hyper_bias)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_conv, W_hyper_bias = params
        # `.view(...)` is a PyTorch idiom; reshape instead.
        # NHWC layout assumed; for NCHW the bias would be reshaped to (1, -1, 1, 1).
        out = apply_fun_wrapped(params_conv, x, **kwargs) + np.dot(
            np.reshape(t, (1, 1)), W_hyper_bias).reshape(1, 1, 1, -1)
        return (out, t)

    return init_fun, apply_fun
Example #12
def BlendConv2D(out_dim,
                W_init=he_normal(),
                b_init=normal(),
                kernel=3,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=True,
                transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        output_shape, params_f = init_fun_wrapped(k1, input_shape)
        _, params_g = init_fun_wrapped(k2, input_shape)
        return output_shape, (params_f, params_g)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_f, params_g = params
        f = apply_fun_wrapped(params_f, x)
        g = apply_fun_wrapped(params_g, x)
        out = f + (g - f) * t
        return (out, t)

    return init_fun, apply_fun
Example #13
def ConcatConv2D(out_dim,
                 W_init=he_normal(),
                 b_init=normal(),
                 kernel=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 transpose=False):
    assert dilation == 1 and groups == 1
    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):  # note: input_shape describes x only; the time channel is added below
        concat_input_shape = list(input_shape)
        concat_input_shape[-1] += 1  # add time channel dim
        concat_input_shape = tuple(concat_input_shape)
        return init_fun_wrapped(rng, concat_input_shape)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        tt = np.ones_like(x[:, :, :, :1]) * t
        xtt = np.concatenate([x, tt], axis=-1)
        out = apply_fun_wrapped(params, xtt, **kwargs)
        return (out, t)

    return init_fun, apply_fun
Example #14
def ConcatLinear(out_dim, W_init=he_normal(), b_init=normal()):
    """ y = Wx + b + at
    """
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2 = random.split(rng)
        W, b = W_init(k1,
                      (input_shape[-1] + 1, out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        W, b = params

        # concatenate t onto the inputs
        tt = t.reshape([-1] * (x.ndim - 1) + [1])  # single batch example
        # i.e. [:, :, ..., :, :1] column vector
        tt = np.tile(tt, x.shape[:-1] + (1, ))
        xtt = np.concatenate([x, tt], axis=-1)

        return (np.dot(xtt, W) + b, t)

    return init_fun, apply_fun
Example #15
    def __call__(self, key, shape, dtype=None):
        if dtype is None:
            dtype = "float32"
        initializer_fn = jax_initializers.he_normal()
        return initializer_fn(key, shape, dtype)
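
This `__call__` appears to belong to a Keras-style initializer wrapper, with `jax_initializers` presumably aliasing `jax.nn.initializers`; the equivalent direct call would be:

import jax
from jax.nn import initializers as jax_initializers

key = jax.random.PRNGKey(0)
kernel = jax_initializers.he_normal()(key, (3, 3, 64, 128), "float32")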
Example #16
def GCNLayer(out_dim,
             activation=relu,
             bias=True,
             normalize=True,
             batch_norm=False,
             dropout=0.0,
             W_init=he_normal(),
             b_init=normal()):
    r"""Single GCN layer from `Semi-Supervised Classification with Graph Convolutional Networks
    <https://arxiv.org/abs/1609.02907>`

    Parameters
    ----------
    out_dim : int
        Number of output node features.
    activation : Function
        Activation function, defaults to relu.
    bias : bool
        Whether to add a bias after the affine transformation, defaults to True.
    normalize : bool
        Whether to normalize the adjacency matrix, defaults to True.
    batch_norm : bool
        Whether to use batch normalization, defaults to False.
    dropout : float
        The dropout probability, defaults to 0.0.
    W_init : initializer function for the weight
        Defaults to He normal initialization.
    b_init : initializer function for the bias
        Defaults to normal initialization.

    Returns
    -------
    init_fun : Function
        Initializes the parameters of the layer.
    apply_fun : Function
        Defines the forward computation function.
    """

    _, drop_fun = Dropout(dropout)
    batch_norm_init, batch_norm_fun = BatchNorm()

    def init_fun(rng, input_shape):
        """Initialize parameters.

        Parameters
        ----------
        rng : PRNGKey
            rng is a value for generating random values.
        input_shape : (batch_size, N, M1)
            The shape of input (input node features).
            N is the total number of nodes in the batch of graphs.
            M1 is the input node feature size.

        Returns
        -------
        output_shape : (batch_size, N, M2)
            The shape of output (new node features).
            M2 is the new node feature size and equal to out_dim.
        params : Tuple (W, b, batch_norm_param)
            W is the weight and b is the bias.
            W : ndarray of shape (M1, M2)
            b : ndarray of shape (M2,) or None
            batch_norm_param : Tuple (beta, gamma) or None
        """
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2, k3 = random.split(rng, 3)
        W = W_init(k1, (input_shape[-1], out_dim))
        b = b_init(k2, (out_dim, )) if bias else None
        batch_norm_param = None
        if batch_norm:
            output_shape, batch_norm_param = batch_norm_init(k3, output_shape)
        return output_shape, (W, b, batch_norm_param)

    def apply_fun(params, node_feats, adj, rng, is_train):
        """Update node representations.

        Parameters
        ----------
        node_feats : ndarray of shape (batch_size, N, M1)
            Batched input node features.
            N is the total number of nodes in the batch of graphs.
            M1 is the input node feature size.
        adj : ndarray of shape (batch_size, N, N)
            Batched adjacency matrix.
        rng : PRNGKey
            rng is a value for generating random values
        is_train : bool
            Whether the model is training or not.

        Returns
        -------
        new_node_feats : ndarray of shape (batch_size, N, M2)
            Batched new node features.
            M2 is the new node feature size and equal to out_dim.
        """
        W, b, batch_norm_param = params

        if normalize:
            # A' = A + I, where I is the identity matrix
            # D': diagonal node degree matrix of A'
            # H' = D'^(-1/2) × A' × D'^(-1/2) × H × W
            def node_update_func(node_feats, adj):
                adj = adj + jnp.eye(len(adj))
                deg = jnp.sum(adj, axis=1)
                deg_mat = jnp.diag(jnp.where(deg > 0, deg**(-0.5), 0))
                normalized_adj = jnp.dot(deg_mat, jnp.dot(adj, deg_mat))
                return jnp.dot(normalized_adj, jnp.dot(node_feats, W))
        else:
            # H' = A × H × W
            def node_update_func(node_feats, adj):
                return jnp.dot(adj, jnp.dot(node_feats, W))

        # batched operation for updating node features
        new_node_feats = vmap(node_update_func)(node_feats, adj)

        if bias:
            new_node_feats += b
        new_node_feats = activation(new_node_feats)
        if dropout != 0.0:
            rng, key = random.split(rng)
            new_node_feats = drop_fun(None, new_node_feats, is_train, rng=key)
        if batch_norm:
            new_node_feats = batch_norm_fun(batch_norm_param, new_node_feats)
        return new_node_feats

    return init_fun, apply_fun
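
A minimal usage sketch matching the shapes in the docstring; `relu`, `Dropout`, `BatchNorm`, and `vmap` are assumed to come from the surrounding project and its imports, so only the default (no dropout, no batch norm) call pattern is exercised:

import jax.numpy as jnp
from jax import random

batch_size, N, M1, M2 = 4, 10, 16, 32
init_fun, apply_fun = GCNLayer(out_dim=M2)
out_shape, params = init_fun(random.PRNGKey(0), (batch_size, N, M1))

node_feats = jnp.ones((batch_size, N, M1))
adj = jnp.ones((batch_size, N, N))          # batched adjacency matrices
new_feats = apply_fun(params, node_feats, adj,
                      rng=random.PRNGKey(1), is_train=True)
# new_feats: (batch_size, N, M2)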