# Example No. 1
0
def generate_data_01():
    """Build one synthetic batch and run it through a small homotopy network.

    The network is ``stax.serial`` of two ``HomotopyDense`` layers (``out_dim``
    4 then 1) followed by ``Sigmoid``. Parameters are initialised with
    ``random.PRNGKey(0)``, and the forward pass uses the continuation parameter
    ``bparam[0]`` (0.0).

    Returns:
        tuple ``(inputs, logits, ae_params, bparam, init_func, predict_func)``:
            inputs: float32 array of shape ``(8, 4)`` drawn with ``npr.rand``.
            logits: output of ``predict_func`` on ``inputs``.
            ae_params: parameters produced by ``init_func``.
            bparam: one-element list holding ``np.array([0.0])`` (float64).
            init_func, predict_func: the ``stax.serial`` pair.
    """
    batch_size = 8
    input_shape = (batch_size, 4)

    # A single draw consumes the same RNG state as taking one item from an
    # infinite generator would.
    inputs = npr.rand(*input_shape).astype("float32")

    init_func, predict_func = stax.serial(
        HomotopyDense(out_dim=4, W_init=glorot_uniform(), b_init=normal()),
        HomotopyDense(out_dim=1, W_init=glorot_uniform(), b_init=normal()),
        Sigmoid,
    )

    # The output shape returned by init_func is not needed here.
    _, ae_params = init_func(random.PRNGKey(0), input_shape)
    bparam = [np.array([0.0], dtype=np.float64)]
    logits = predict_func(ae_params,
                          inputs,
                          bparam=bparam[0],
                          activation_func=sigmoid)

    return inputs, logits, ae_params, bparam, init_func, predict_func
# Example No. 2
0
    def init_fun(rng, input_shape):
        """Initialise a projection matrix plus two column vectors.

        Args:
            rng: JAX PRNG key.
            input_shape: input shape tuple; only its last entry is used to
                size the projection.

        Returns:
            ``(output_shape, (W, a1, a2))``. ``output_shape`` is
            ``input_shape`` with its last entry replaced by ``out_dim``.
            ``W`` has shape ``(input_shape[-1], out_dim)``. ``a1`` and ``a2``
            each have shape ``(out_dim, 1)``. All three are Glorot-uniform.
        """
        # Four sub-keys are requested although only three are consumed.
        # In JAX's default PRNG, changing ``num`` changes the derived keys,
        # so it is kept at 4 to reproduce the same initial values.
        proj_key, a1_key, a2_key, _ = random.split(rng, 4)
        glorot = glorot_uniform()

        proj = glorot(proj_key, (input_shape[-1], out_dim))
        a1, a2 = (glorot(key, (out_dim, 1)) for key in (a1_key, a2_key))

        return input_shape[:-1] + (out_dim,), (proj, a1, a2)
# Example No. 3
0
 def init_fun(rng, input_shape):
     """Initialise a dense layer's weight matrix and optional bias vector.

     Args:
         rng: JAX PRNG key, split into one key for the weights and one for
             the bias.
         input_shape: input shape tuple; its last entry is the fan-in.

     Returns:
         ``(output_shape, (W, b))``. ``output_shape`` is ``input_shape`` with
         its last entry replaced by ``out_dim``. ``W`` is Glorot-uniform with
         shape ``(input_shape[-1], out_dim)``. ``b`` comes from the ``zeros``
         initializer with shape ``(out_dim,)``, or is ``None`` when the
         enclosing ``bias`` flag is falsy.
     """
     weight_key, bias_key = random.split(rng)
     weight_init = glorot_uniform()
     bias_init = zeros

     weights = weight_init(weight_key, (input_shape[-1], out_dim))
     biases = bias_init(bias_key, (out_dim,)) if bias else None

     return input_shape[:-1] + (out_dim,), (weights, biases)
# Example No. 4
0
    def init_fun(rng, input_shape, raw_input_dim=1433):
        """Initialise the parameters of a gated infusion layer.

        Args:
            rng: JAX PRNG key, split into five sub-keys (one per parameter).
            input_shape: input shape tuple; its last entry is the fan-in for
                ``W_t``, ``Theta`` and ``W_h``.
            raw_input_dim: number of rows of ``W_x``, the projection used
                only in the raw infusion. Defaults to 1433, the value that was
                previously hard-coded for the Cora dataset, so existing
                ``init_fun(rng, input_shape)`` callers are unaffected.

        Returns:
            ``(output_shape, (W_t, b_t, Theta, W_h, W_x))``. ``output_shape``
            is ``input_shape`` with its last entry replaced by ``out_dim``.
            Every matrix has ``out_dim`` columns. ``b_t`` has shape
            ``(out_dim,)`` and comes from the ``zeros`` initializer.
        """
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2, k3, k4, k5 = random.split(rng, num=5)
        W_init, b_init = glorot_uniform(), zeros

        # used for the gating function
        W_t = W_init(k1, (input_shape[-1], out_dim))
        b_t = b_init(k2, (out_dim,))

        # used for the homogeneous representation
        Theta = W_init(k3, (input_shape[-1], out_dim))

        # projection used in the outer infusion
        W_h = W_init(k4, (input_shape[-1], out_dim))
        # used only in the raw infusion; its row count is the raw feature
        # width, which need not equal input_shape[-1]
        W_x = W_init(k5, (raw_input_dim, out_dim))

        return output_shape, (W_t, b_t, Theta, W_h, W_x)
# Example No. 5
0
def BlockMaskedDense(num_blocks,
                     in_factor,
                     out_factor,
                     bias=True,
                     W_init=glorot_uniform()):
    """
    Module that implements a linear layer with a block upper-triangular weight
    matrix whose diagonal blocks are strictly positive (``exp`` of the raw
    weights). It also uses Weight Normalization
    (https://arxiv.org/abs/1602.07868) for stability.

    :param int num_blocks: Number of block matrices along each axis.
    :param int in_factor: number of rows in each block.
    :param int out_factor: number of columns in each block.
    :param bool bias: whether to add a bias vector of length
        ``num_blocks * out_factor`` to the output.
    :param W_init: initialization method for the weights.
    :return: an (`init_fun`, `apply_fun`) pair.

        ``init_fun(rng, input_shape)`` returns ``(output_shape, params)``,
        where ``params`` is ``(W, ws, b)`` if ``bias`` else ``(W, ws)``.

        ``apply_fun(params, (x, logdet))`` returns ``(out, logdet)``.
        ``logdet`` holds the log of the normalized diagonal-block entries with
        shape ``x.shape[:-1] + (num_blocks, in_factor, out_factor)``. If an
        incoming ``logdet`` is given, it is combined with them via
        ``logmatmulexp``.
    """
    input_dim, out_dim = num_blocks * in_factor, num_blocks * out_factor
    # Diagonal block mask: ones on the (in_factor x out_factor) blocks (i, i).
    mask_d = np.identity(num_blocks)[..., None]
    mask_d = np.tile(mask_d,
                     (1, in_factor, out_factor)).reshape(input_dim, out_dim)
    # Off-diagonal block mask: ones on the blocks (i, j) with j > i, i.e. the
    # strictly upper triangle at block level.
    mask_o = jnp.triu(jnp.ones((num_blocks, num_blocks)), k=1)[..., None]
    mask_o = jnp.tile(mask_o,
                      (1, in_factor, out_factor)).reshape(input_dim, out_dim)

    def init_fun(rng, input_shape):
        assert input_dim == input_shape[-1], (
            f"expected last input dim {input_dim}, got {input_shape[-1]}")
        *k1, k2, k3 = random.split(rng, num_blocks + 2)

        # Column block i is non-zero only in row blocks 0..i (upper
        # triangular at block level); fill each with W_init.
        W = jnp.zeros((input_dim, out_dim))
        for i in range(num_blocks):
            W = W.at[:(i + 1) * in_factor,
                     i * out_factor:(i + 1) * out_factor].add(
                         W_init(k1[i], ((i + 1) * in_factor, out_factor)))

        # initialize weight scale (log-space)
        ws = jnp.log(uniform(1.)(k2, (out_dim, )))

        if bias:
            b = (uniform(1.)(k3, (out_dim, )) - 0.5) * (2 / jnp.sqrt(out_dim))
            params = (W, ws, b)
        else:
            params = (W, ws)
        return input_shape[:-1] + (out_dim, ), params

    def apply_fun(params, inputs, **kwargs):
        x, logdet = inputs
        if bias:
            W, ws, b = params
        else:
            W, ws = params

        # Form block weight matrix, making sure it's positive on diagonal!
        w = jnp.exp(W) * mask_d + W * mask_o

        # Compute norm of each column (i.e. each output feature)
        w_norm = jnp.linalg.norm(w, axis=-2, keepdims=True)

        # Normalize weight and rescale
        w = jnp.exp(ws) * w / w_norm

        out = jnp.dot(x, w)
        if bias:
            out = out + b

        # On the diagonal blocks, log(w) = ws + W - log(w_norm) exactly,
        # because those entries were exp(W) before normalization.
        dense_logdet = ws + W - jnp.log(w_norm)
        # logdet of block diagonal
        dense_logdet = dense_logdet[mask_d.astype(bool)].reshape(
            num_blocks, in_factor, out_factor)
        if logdet is None:
            logdet = jnp.broadcast_to(dense_logdet,
                                      x.shape[:-1] + dense_logdet.shape)
        else:
            logdet = logmatmulexp(logdet, dense_logdet)
        return out, logdet

    return init_fun, apply_fun
# Example No. 6
0
 def __call__(self, key, shape, dtype=None):
     """Sample Glorot-uniform initial values.

     Args:
         key: JAX PRNG key.
         shape: shape of the array to create.
         dtype: dtype of the result; ``None`` selects ``"float32"``.

     Returns:
         The array produced by ``jax_initializers.glorot_uniform()``.
     """
     glorot = jax_initializers.glorot_uniform()
     return glorot(key, shape, "float32" if dtype is None else dtype)