Example No. 1
from jax.experimental import optimizers, stax
import jax.numpy as jnp

from dataset.make_data import equation_of_motion, rk4_step, solve_lagrangian
from dataset.plot import normalize_dp
from visualization import plot_loss

TRAIN_DATASET_PATH = "../data/train_data.pickle"
TEST_DATASET_PATH = "../data/test_data.pickle"
LOG_DIR = "./logs"

# build a neural network model
init_random_params, nn_forward_fn = stax.serial(
    stax.Dense(128),
    stax.Softplus,
    stax.Dense(128),
    stax.Softplus,
    stax.Dense(1),
)


# replace the lagrangian with a parametric model
def learned_lagrangian(params):
    def lagrangian(q, q_t):
        assert q.shape == (2, )
        state = normalize_dp(jnp.concatenate([q, q_t]))
        return jnp.squeeze(nn_forward_fn(params, state), axis=-1)

    return lagrangian
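
The comment above sketches the core idea: the analytic Lagrangian is replaced by a neural network. Below is a minimal training-loss sketch, assuming (as the import suggests, but this snippet does not confirm) that equation_of_motion(lagrangian, state) returns the predicted time derivative of the state:

from functools import partial
import jax

def loss(params, batch):
    states, targets = batch  # assumed: (N, 4) arrays for the double pendulum
    preds = jax.vmap(partial(equation_of_motion, learned_lagrangian(params)))(states)
    return jnp.mean((preds - targets) ** 2)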

Example No. 2
def repeat(layer, num_repeats):
    """Repeats layers serially num_repeats times."""
    if num_repeats < 1:
        raise ValueError('Repeat combinator num_repeats must be >= 1.')
    layers = num_repeats * (layer, )
    return stax.serial(*layers)
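
A quick usage sketch of this combinator (assuming the jax.experimental.stax import used by the surrounding snippets): repeating a Dense+Relu pair three times builds the same network as writing the layers out by hand, and each repetition still gets its own parameters because stax.serial splits the rng per layer at init time.

from jax import random
from jax.experimental import stax

block = stax.serial(stax.Dense(64), stax.Relu)
init_fun, apply_fun = repeat(block, 3)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 32))  # out_shape == (-1, 64)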
Example No. 3
    def initialize_parametric_nonlinearity(self,
                                           init_to='exponential',
                                           method=None,
                                           params_dict=None):

        if method is None:  # if no method is specified, use the default.
            # this piece of code is quite redundant.
            # need to refactor.
            if hasattr(self, 'nonlinearity'):
                method = self.nonlinearity
            else:
                method = self.filter_nonlinearity
        else:  # overwrite the default nonlinearity
            if hasattr(self, 'nonlinearity'):
                self.nonlinearity = method
            else:
                self.filter_nonlinearity = method
                self.output_nonlinearity = method

        # prepare data
        if params_dict is None:
            params_dict = {}
        xrange = params_dict.get('xrange', 5)
        nx = params_dict.get('nx', 1000)
        x0 = np.linspace(-xrange, xrange, nx)
        if init_to == 'exponential':
            y0 = np.exp(x0)

        elif init_to == 'softplus':
            y0 = softplus(x0)

        elif init_to == 'relu':
            y0 = relu(x0)

        elif init_to == 'nonparametric':
            y0 = self.fnl_nonparametric(x0)

        elif init_to == 'gaussian':
            import scipy.signal
            y0 = scipy.signal.gaussian(nx, nx / 10)

        else:
            raise ValueError(f"Unknown init_to '{init_to}'.")

        # fit nonlin
        if method == 'spline':
            smooth = params_dict.get('smooth', 'cr')
            df = params_dict.get('df', 7)
            if smooth == 'cr':
                X = cr(x0, df)
            elif smooth == 'cc':
                X = cc(x0, df)
            elif smooth == 'bs':
                deg = params_dict.get('degree', 3)
                X = bs(x0, df, deg)

            opt_params = np.linalg.pinv(X.T @ X) @ X.T @ y0

            self.nl_basis = X

            def _nl(opt_params, x_new):
                return np.maximum(interp1d(x0, X @ opt_params)(x_new), 0)

        elif method == 'nn':

            def loss(params, data):
                x = data['x']
                y = data['y']
                yhat = _predict(params, x)
                return np.mean((y - yhat)**2)

            @jit
            def step(i, opt_state, data):
                p = get_params(opt_state)
                g = grad(loss)(p, data)
                return opt_update(i, g, opt_state)

            random_seed = params_dict.get('random_seed', 2046)
            key = random.PRNGKey(random_seed)

            step_size = params_dict.get('step_size', 0.01)
            layer_sizes = params_dict.get('layer_sizes', [10, 10, 1])
            layers = []
            for layer_size in layer_sizes:
                layers.append(Dense(layer_size))
                layers.append(BatchNorm(axis=(0, 1)))
                layers.append(Relu)
            layers.pop(-1)  # drop the trailing Relu so the final output is not rectified

            init_random_params, _predict = stax.serial(*layers)

            num_subunits = params_dict.get('num_subunits', 1)
            _, init_params = init_random_params(key, (-1, num_subunits))

            opt_init, opt_update, get_params = optimizers.adam(step_size)
            opt_state = opt_init(init_params)

            num_iters = params_dict.get('num_iters', 1000)
            if num_subunits == 1:
                data = {'x': x0.reshape(-1, 1), 'y': y0.reshape(-1, 1)}
            else:
                data = {
                    'x': np.vstack([x0 for i in range(num_subunits)]).T,
                    'y': y0.reshape(-1, 1)
                }

            for i in range(num_iters):
                opt_state = step(i, opt_state, data)
            opt_params = get_params(opt_state)

            def _nl(opt_params, x_new):
                if len(x_new.shape) == 1:
                    x_new = x_new.reshape(-1, 1)
                return np.maximum(_predict(opt_params, x_new), 0)

        self.nl_xrange = x0
        self.nl_params = opt_params
        self.fnl_fitted = _nl
Example No. 4
def main():
    total_time = 20.0
    gamma = 1.0
    x_dim = 2
    z_dim = 32
    rng = random.PRNGKey(0)

    x0 = jp.array([2.0, 1.0])

    ### Set up the problem/environment
    # xdot = Ax + Bu
    # u = - Kx
    # cost = xQx + uRu + 2xNu
    A, B, Q, R, N = fixed_env(x_dim)
    dynamics_fn = lambda x, u: A @ x + B @ u
    cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u
    policy_loss = policy_integrate_cost(x_dim, z_dim, dynamics_fn, cost_fn,
                                        gamma)

    ### Solve the Riccati equation to get the infinite-horizon optimal solution.
    K, _, _ = control.lqr(A, B, Q, R, N)
    K = jp.array(K)

    t0 = time.time()
    opt_y_fwd, opt_y_bwd = policy_loss(lambda _, x: -K @ x)(None, x0,
                                                            total_time)
    opt_cost = opt_y_fwd[1, 0]
    print(f"opt_cost = {opt_cost} in {time.time() - t0}s")
    print(opt_y_fwd)
    print(opt_y_bwd)
    print(f"l2 error: {jp.sqrt(jp.sum((opt_y_fwd - opt_y_bwd)**2))}")

    ### Set up the learned policy model.
    policy_init, policy = stax.serial(
        Dense(64),
        Tanh,
        Dense(x_dim),
    )

    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (x_dim, ))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)
    runny_run = jit(policy_loss(policy))

    ### Main optimization loop.
    costs = []
    bwd_errors = []
    for i in range(5000):
        t0 = time.time()
        (y_fwd, y_bwd), vjp = jax.vjp(runny_run, opt.value, x0, total_time)
        cost = y_fwd[1, 0]

        y_fwd_bar = jax.ops.index_update(jp.zeros_like(y_fwd), (1, 0), 1)
        g, _, _ = vjp((y_fwd_bar, jp.zeros_like(y_bwd)))
        opt = opt.update(g)

        bwd_err = jp.sqrt(jp.sum((y_fwd - y_bwd)**2))
        bwd_errors.append(bwd_err)

        print(f"Episode {i}:")
        print(f"    excess cost = {cost - opt_cost}")
        print(f"    bwd error = {bwd_err}")
        print(f"    elapsed = {time.time() - t0}")
        costs.append(float(cost))

    print(f"Opt solution cost from starting point: {opt_cost}")

    ### Plot performance per iteration, incl. average optimal policy performance.
    _, ax1 = plt.subplots()
    ax1.set_xlabel("Iteration")
    ax1.set_ylabel("Cost", color="tab:blue")
    ax1.set_yscale("log")
    ax1.tick_params(axis="y", labelcolor="tab:blue")
    ax1.plot(costs, color="tab:blue")
    plt.axhline(opt_cost, linestyle="--", color="gray")

    ax2 = ax1.twinx()
    ax2.set_ylabel("Backward solve L2 error", color="tab:red")
    ax2.set_yscale("log")
    ax2.tick_params(axis="y", labelcolor="tab:red")
    ax2.plot(bwd_errors, color="tab:red")
    plt.title("ODE control of LQR problem")

    blt.show()
Example No. 5
    def testSerialComposeLayersShape(self, input_shape, spec):
        init_fun, apply_fun = stax.serial(*spec)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
Example No. 6
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(np.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)


init_random_params, predict = stax.serial(
    Conv(10, (5, 5), (1, 1)), Activator,
    MaxPool((4, 4)), Flatten,
    Dense(10), LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)
    
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    # input shape for CNN
    input_shape = (-1, 28, 28, 1)
    
    # training/test split
    (train_images, train_labels), (test_images, test_labels) = mnist_data.tiny_mnist(flatten=False)
Example No. 7
import jax
import jax.experimental.stax as stax
import jax.experimental.optimizers as opt
from functools import partial
import numpy as np
import matplotlib.pyplot as plt

###############################################
########## SETTING UP NEURAL NETWORK ##########
###############################################

net_init, net_apply = stax.serial(
    stax.Dense(40), stax.Relu,
    stax.Dense(40), stax.Relu,
    stax.Dense(1)
)

input_shape = (-1, 1,)
rng = jax.random.PRNGKey(0)
output_shape, net_params = net_init(rng, input_shape)

def loss(params, inputs, targets):
    predictions = net_apply(params, inputs)
    return jax.numpy.mean((targets - predictions)**2)

xrange_inputs = jax.numpy.linspace(-5, 5, 100).reshape((100, 1))
targets = jax.numpy.sin(xrange_inputs)
predictions = jax.vmap(partial(net_apply, net_params))(xrange_inputs)
losses = jax.vmap(partial(loss, net_params))(xrange_inputs, targets)

####################################################
########## PLOTTING UNINITIALIZED NETWORK ##########
Example No. 8
def main():
    rng = random.PRNGKey(0)
    x_dim = 2
    T = 20.0

    policy_init, policy = stax.serial(
        Dense(64),
        Tanh,
        Dense(x_dim),
    )

    x0 = jnp.ones(x_dim)

    A, B, Q, R, N = fixed_env(x_dim)
    print("System dynamics:")
    print(f"  A = {A}")
    print(f"  B = {B}")
    print(f"  Q = {Q}")
    print(f"  R = {R}")
    print(f"  N = {N}")
    print()

    dynamics_fn = lambda x, u: A @ x + B @ u
    cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

    ### Evaluate LQR solution to get a sense of optimal cost.
    K, _, _ = control.lqr(A, B, Q, R, N)
    K = jnp.array(K)
    opt_loss, _opt_K_grad = policy_loss_and_grad(dynamics_fn, cost_fn, T,
                                                 lambda KK, x: -KK @ x)(x0, K)
    # This is true for longer time horizons, but not true for shorter time
    # horizons due to the LQR solution being an infinite-time solution.
    # assert jnp.allclose(opt_K_grad, 0)

    ### Training loop.
    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (x_dim, ))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)
    loss_and_grad = policy_loss_and_grad(dynamics_fn, cost_fn, T, policy)

    loss_per_iter = []
    elapsed_per_iter = []
    for iteration in range(10000):
        t0 = time.time()
        loss, g = loss_and_grad(x0, opt.value)
        opt = opt.update(g)
        elapsed = time.time() - t0

        loss_per_iter.append(loss)
        elapsed_per_iter.append(elapsed)

        print(f"Iteration {iteration}")
        print(f"    excess loss = {loss - opt_loss}")
        print(f"    elapsed = {elapsed}")

    blt.remember({
        "loss_per_iter": loss_per_iter,
        "elapsed_per_iter": elapsed_per_iter,
        "opt_loss": opt_loss
    })

    _, ax1 = plt.subplots()
    ax1.set_xlabel("Iteration")
    ax1.set_ylabel("Cost", color="tab:blue")
    ax1.set_yscale("log")
    ax1.tick_params(axis="y", labelcolor="tab:blue")
    ax1.plot(loss_per_iter, color="tab:blue", label="Total rollout cost")
    plt.axhline(opt_loss, linestyle="--", color="gray")
    ax1.legend(loc="upper left")
    plt.title("Combined fwd-bwd BVP problem")
    blt.show()
Example No. 9
def fcnn(
    output_dim: int,
    layers=2,
    units=256,
    skips=False,
    output_activation=None,
    first_layer_factor=None,
    parallel_with=None,
):
    """
    Creates init and apply functions for a fully-connected NN with a Gaussian
    likelihood.

    :param output_dim: Number of output dimensions
    :param first_layer_factor: Scaling factor for first dense layer initialization.
    """

    first_layer_factor = 10.0 if first_layer_factor is None else first_layer_factor

    activation = stax.elementwise(np.sin)
    hidden = stax.serial(Dense(units, W_init=siren_w_init()()), activation)
    block = skip_connect(hidden) if skips else hidden

    layer_list = [
        stax.serial(
            Dense(
                units,
                W_init=siren_w_init(factor=first_layer_factor)(),
                b_init=jax.nn.initializers.normal(stddev=2.0 * math.pi),
            ),
            activation,
        ),
        *([block] * (layers - 1)),
        Dense(output_dim),  # No skips on the last one!
    ]
    if output_activation is not None:
        layer_list.append(output_activation)
    if parallel_with is None:
        _init_fun, _apply_fun = stax.serial(*layer_list)
    else:
        _init_fun, _apply_fun = parallel(stax.serial(*layer_list),
                                         parallel_with)

    def init_fun(rng, input_shape):
        output_shape, net_params = _init_fun(rng, input_shape)
        params = {"net": net_params, "raw_noise": np.array(-2.0)}
        return output_shape, params

    # Conform to API:
    @t_wrapper
    def apply_fun(params, rng, inputs):
        return _apply_fun(params["net"], inputs)

    @t_wrapper
    def gaussian_fun(params, rng, inputs):
        pred_mean = apply_fun(params, None, inputs)
        pred_std = params["noise"] * np.ones(pred_mean.shape)
        return pred_mean, pred_std

    @t_wrapper
    def loss_fun(params,
                 rng,
                 data,
                 batch_size=None,
                 n=None,
                 loss_type="nlp",
                 reduce="sum"):
        """
        :param batch_size: How large a batch to subselect from the provided data
        :param n: The total size of the dataset (to multiply batch estimate by)
        """
        assert loss_type in ("nlp", "mse")
        inputs, targets = data
        n = inputs.shape[0] if n is None else n
        if batch_size is not None:
            rng, rng_batch = random.split(rng)
            i = random.permutation(rng_batch, n)[:batch_size]
            inputs, targets = inputs[i], targets[i]
        preds = apply_fun(params, rng, inputs).squeeze()
        mean_loss = (
            -norm.logpdf(targets.squeeze(), preds, params["noise"]).mean()
            if loss_type == "nlp" else np.power(targets.squeeze() -
                                                preds, 2).mean())
        if reduce == "sum":
            loss = n * mean_loss
        elif reduce == "mean":
            loss = mean_loss
        return loss

    def sample_fun_fun(rng, params):
        def f(x):
            return apply_fun(params, rng, x)

        return f

    return {
        "init": init_fun,
        "apply": apply_fun,
        "gaussian": gaussian_fun,
        "loss": loss_fun,
        "sample_function": sample_fun_fun,
    }
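
A minimal usage sketch of the returned function dictionary. It assumes the module-level helpers referenced above (siren_w_init, skip_connect, t_wrapper, parallel, norm) are in scope, that t_wrapper keeps the (params, rng, inputs) calling convention, and that x_batch and y_batch are placeholder arrays of shape (N, 2) and (N, 1):

model = fcnn(output_dim=1, layers=3, units=128)
rng = random.PRNGKey(0)
out_shape, params = model["init"](rng, (-1, 2))
preds = model["apply"](params, rng, x_batch)                         # mean predictions
mse = model["loss"](params, rng, (x_batch, y_batch), loss_type="mse")
f = model["sample_function"](rng, params)                            # plain x -> prediction closure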
Example No. 10
    def create_surrogate(self):
        surrogate_init, surrogate = stax.serial(Dense(200), Relu,
                                                Dense(200), Relu,
                                                Dense(200), Relu, Dense(1))
        return surrogate, surrogate_init
Example No. 11
def AutoregressiveNN(input_dim,
                     hidden_dims,
                     param_dims=[1, 1],
                     permutation=None,
                     skip_connections=False,
                     nonlinearity=stax.Relu):
    """
    An implementation of a MADE-like auto-regressive neural network.

    Similar to the purely functional layer implemented in jax.experimental.stax,
    the `AutoregressiveNN` class has `init_fun` and `apply_fun` methods,
    where `init_fun` takes an rng_key key and an input shape and returns an
    (output_shape, params) pair, and `apply_fun` takes params and inputs
    and applies the layer.

    :param input_dim: the dimensionality of the input
    :type input_dim: int
    :param hidden_dims: the dimensionality of the hidden units per layer
    :type hidden_dims: list[int]
    :param param_dims: shape the output into parameters of dimension (p_n, input_dim) for p_n in param_dims
        when p_n > 1 and dimension (input_dim) when p_n == 1. The default is [1, 1], i.e. output two parameters
        of dimension (input_dim), which is useful for inverse autoregressive flow.
    :type param_dims: list[int]
    :param permutation: an optional permutation that is applied to the inputs and controls the order of the
        autoregressive factorization. in particular for the identity permutation the autoregressive structure
        is such that the Jacobian is triangular. Defaults to identity permutation.
    :type permutation: array of ints
    :param skip_connections: whether to add skip connections from the input to the output.
    :type skip_connections: bool
    :param nonlinearity: The nonlinearity to use in the feedforward network such as ReLU. Note that no
        nonlinearity is applied to the final network output, so the output is an unbounded real number.
        defaults to ReLU.
    :type nonlinearity: callable.
    :return: a tuple (init_fun, apply_fun)

    Reference:

    MADE: Masked Autoencoder for Distribution Estimation [arXiv:1502.03509]
    Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle
    """
    output_multiplier = sum(param_dims)
    all_ones = (np.array(param_dims) == 1).all()

    # Calculate the indices on the output corresponding to each parameter
    ends = np.cumsum(np.array(param_dims), axis=0)
    starts = np.concatenate((np.zeros(1), ends[:-1]))
    param_slices = [slice(int(s), int(e)) for s, e in zip(starts, ends)]

    # The hidden dimensions must not be smaller than the input dimension;
    # otherwise it isn't possible to connect to the outputs correctly.
    for h in hidden_dims:
        if h < input_dim:
            raise ValueError(
                'Hidden dimension must not be less than input dimension.')

    if permutation is None:
        permutation = jnp.arange(input_dim)

    # Create masks
    masks, mask_skip = create_mask(input_dim=input_dim,
                                   hidden_dims=hidden_dims,
                                   permutation=permutation,
                                   output_dim_multiplier=output_multiplier)

    main_layers = []
    # Create masked layers
    for i, mask in enumerate(masks):
        main_layers.append(MaskedDense(mask))
        if i < len(masks) - 1:
            main_layers.append(nonlinearity)

    if skip_connections:
        net_init, net = stax.serial(
            stax.FanOut(2),
            stax.parallel(stax.serial(*main_layers),
                          MaskedDense(mask_skip, bias=False)), stax.FanInSum)
    else:
        net_init, net = stax.serial(*main_layers)

    def init_fun(rng_key, input_shape):
        """
        :param rng_key: rng_key used to initialize parameters
        :param input_shape: input shape
        """
        assert input_dim == input_shape[-1]
        return net_init(rng_key, input_shape)

    def apply_fun(params, inputs, **kwargs):
        """
        :param params: layer parameters
        :param inputs: layer inputs
        """
        out = net(params, inputs, **kwargs)

        # reshape output as necessary
        out = jnp.reshape(out,
                          inputs.shape[:-1] + (output_multiplier, input_dim))
        # move param dims to the first dimension
        out = jnp.moveaxis(out, -2, 0)

        if all_ones:
            # Squeeze dimension if all parameters are one dimensional
            out = tuple([out[i] for i in range(output_multiplier)])
        else:
            # If not all ones, then probably don't want to squeeze a single dimension parameter
            out = tuple([out[s] for s in param_slices])

        # if len(param_dims) == 1, we return the array instead of a tuple of arrays
        return out[0] if len(param_dims) == 1 else out

    return init_fun, apply_fun
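
A minimal usage sketch of the (init_fun, apply_fun) pair described in the docstring, assuming MaskedDense and create_mask are importable from the same module; the shapes here are arbitrary:

from jax import random
import jax.numpy as jnp

init_fun, apply_fun = AutoregressiveNN(input_dim=5, hidden_dims=[10, 10])
out_shape, params = init_fun(random.PRNGKey(0), (-1, 5))
x = jnp.ones((4, 5))
p1, p2 = apply_fun(params, x)  # two (4, 5) arrays, one per entry of the default param_dims=[1, 1]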
Example No. 12
    print("Starting STAX demo")
    # Stax version
    from jax.experimental import stax

    def const_init(params):
        def init(rng_key, shape):
            return params

        return init

    #net_init, net_apply = stax.serial(stax.Dense(1), stax.elementwise(sigmoid))
    dense_layer = stax.Dense(1,
                             W_init=const_init(np.reshape(w, (D, 1))),
                             b_init=const_init(np.array([0.0])))
    net_init, net_apply = stax.serial(dense_layer)
    rng = jax.random.PRNGKey(0)
    in_shape = (-1, D)
    out_shape, net_params = net_init(rng, in_shape)

    def NLL_model(net_params, net_apply, batch):
        X, y = batch
        logits = net_apply(net_params, X)
        return BCE_with_logits(logits, y)

    y_pred2 = net_apply(net_params, X_test)
    loss2 = NLL_model(net_params, net_apply, (X_test, y_test))
    grad_jax2 = grad(NLL_model)(net_params, net_apply, (X_test, y_test))
    grad_jax3 = grad_jax2[0][0]  # layer 0, block 0 (weights not bias)
    grad_jax4 = grad_jax3[:, 0]  # column vector
    assert np.allclose(grad_np, grad_jax4)
Example No. 13
def conv_net(mode="train"):
    out_dim = 1
    dim_nums = ("NHWC", "HWIO", "NHWC")
    unit_stride = (1, 1)
    zero_pad = ((0, 0), (0, 0))

    # Primary convolutional layer.
    conv_channels = 32
    conv_init, conv_apply = Conv(out_chan=conv_channels,
                                 filter_shape=(3, 3),
                                 strides=(1, 3),
                                 padding=zero_pad)
    # Group all possible pairs.
    pair_channels, filter_shape = 256, (1, 2)

    # Convolutional block with the same number of channels.
    block_channels = pair_channels
    conv_block_init, conv_block_apply = serial(
        Conv(block_channels, (1, 3), unit_stride, "SAME"),
        Relu,  # One block of convolutions.
        Conv(block_channels, (1, 3), unit_stride, "SAME"),
        Relu,
        Conv(block_channels, (1, 3), unit_stride, "SAME"))
    # Forward pass.
    hidden_size = 2048
    dropout_rate = 0.25
    serial_init, serial_apply = serial(
        Conv(block_channels, (1, 3), (1, 3), zero_pad),
        Relu,  # Using convolution with strides
        Flatten,
        Dense(hidden_size),  # instead of pooling for downsampling.
        #    Dropout(dropout_rate, mode),
        Relu,
        Dense(out_dim))

    def init_fun(rng, input_shape):
        rng, conv_rng, block_rng, serial_rng = jax.random.split(rng, num=4)

        # Primary convolutional layer.
        conv_shape, conv_params = conv_init(conv_rng, (-1, ) + input_shape)

        # Grouping all possible pairs.
        kernel_shape = [
            filter_shape[0], filter_shape[1], conv_channels, pair_channels
        ]
        bias_shape = [1, 1, 1, pair_channels]
        W_init = glorot_normal(in_axis=2, out_axis=3)
        b_init = normal(1e-6)
        k1, k2 = jax.random.split(rng)
        W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
        pair_shape = conv_shape[:2] + (15, ) + (pair_channels, )
        pair_params = (W, b)

        # Convolutional block.
        conv_block_shape, conv_block_params = conv_block_init(
            block_rng, pair_shape)

        # Forward pass.
        serial_shape, serial_params = serial_init(serial_rng, conv_block_shape)
        params = [conv_params, pair_params, conv_block_params, serial_params]
        return serial_shape, params

    def apply_fun(params, inputs):
        conv_params, pair_params, conv_block_params, serial_params = params

        # Apply the primary convolutional layer.
        conv_out = conv_apply(conv_params, inputs)
        conv_out = relu(conv_out)

        # Group all possible pairs.
        W, b = pair_params
        pair_1 = conv_general_dilated(conv_out, W, unit_stride, zero_pad,
                                      (1, 1), (1, 1), dim_nums) + b
        pair_2 = conv_general_dilated(conv_out, W, unit_stride, zero_pad,
                                      (1, 1), (1, 2), dim_nums) + b
        pair_3 = conv_general_dilated(conv_out, W, unit_stride, zero_pad,
                                      (1, 1), (1, 3), dim_nums) + b
        pair_4 = conv_general_dilated(conv_out, W, unit_stride, zero_pad,
                                      (1, 1), (1, 4), dim_nums) + b
        pair_5 = conv_general_dilated(conv_out, W, unit_stride, zero_pad,
                                      (1, 1), (1, 5), dim_nums) + b
        pair_out = jnp.dstack([pair_1, pair_2, pair_3, pair_4, pair_5])
        pair_out = relu(pair_out)

        # Convolutional block.
        conv_block_out = conv_block_apply(conv_block_params, pair_out)

        # Residual connection.
        res_out = conv_block_out + pair_out
        res_out = relu(res_out)

        # Forward pass.
        out = serial_apply(serial_params, res_out)
        return out

    return init_fun, apply_fun
Example No. 14
def main(_):
    rng = random.PRNGKey(0)

    # Load MNIST dataset
    train_images, train_labels, test_images, test_labels = datasets.mnist()

    batch_size = 128
    batch_shape = (-1, 28, 28, 1)
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    train_images = np.reshape(train_images, batch_shape)
    test_images = np.reshape(test_images, batch_shape)

    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    # Model, loss, and accuracy functions
    init_random_params, predict = stax.serial(
        stax.Conv(32, (8, 8), strides=(2, 2), padding="SAME"),
        stax.Relu,
        stax.Conv(128, (6, 6), strides=(2, 2), padding="VALID"),
        stax.Relu,
        stax.Conv(128, (5, 5), strides=(1, 1), padding="VALID"),
        stax.Flatten,
        stax.Dense(128),
        stax.Relu,
        stax.Dense(10),
    )

    def loss(params, batch):
        inputs, targets = batch
        preds = predict(params, inputs)
        return -np.mean(logsoftmax(preds) * targets)

    def accuracy(params, batch):
        inputs, targets = batch
        target_class = np.argmax(targets, axis=1)
        predicted_class = np.argmax(predict(params, inputs), axis=1)
        return np.mean(predicted_class == target_class)

    # Instantiate an optimizer
    opt_init, opt_update, get_params = optimizers.adam(0.001)

    @jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    # Initialize model
    _, init_params = init_random_params(rng, batch_shape)
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    # Training loop
    print("\nStarting training...")
    for epoch in range(FLAGS.nb_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        epoch_time = time.time() - start_time

        # Evaluate model on clean data
        params = get_params(opt_state)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))

        # Evaluate model on adversarial data
        model_fn = lambda images: predict(params, images)
        test_images_fgm = fast_gradient_method(model_fn, test_images,
                                               FLAGS.eps, np.inf)
        test_images_pgd = projected_gradient_descent(model_fn, test_images,
                                                     FLAGS.eps, 0.01, 40,
                                                     np.inf)
        test_acc_fgm = accuracy(params, (test_images_fgm, test_labels))
        test_acc_pgd = accuracy(params, (test_images_pgd, test_labels))

        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy: {}".format(train_acc))
        print("Test set accuracy on clean examples: {}".format(test_acc))
        print("Test set accuracy on FGM adversarial examples: {}".format(
            test_acc_fgm))
        print("Test set accuracy on PGD adversarial examples: {}".format(
            test_acc_pgd))
Example No. 15
d_hidden_size = hidden_size
g_out_dim = 2
d_out_dim = 1
z_size = 64
smooth = 0.0
learning_rate = 1e-4
batch_size = 256
epochs = 10001
nsave = 2000

#Define Generator
gen_init, gen_apply = stax.serial(
    Dense(g_hidden_size), Relu,
    Dense(g_hidden_size), Relu,
    Dense(g_hidden_size), Relu,
    Dense(g_hidden_size), Relu,
    Dense(g_hidden_size), Relu,
    Dense(g_hidden_size), Relu,
    Dense(g_out_dim)
)
g_in_shape = (-1, z_size)
_, gen_params = gen_init(rng, g_in_shape)

#Define Discriminator
disc_init, disc_apply = stax.serial(
    Dense(d_hidden_size), Relu,
    Dense(d_hidden_size), Relu,
    Dense(d_hidden_size), Relu,
    Dense(d_hidden_size), Relu,
    Dense(d_hidden_size), Relu,
    Dense(d_hidden_size), Relu,
Example No. 16
def action_encoder(output_num):
    return serial(
        Dense(128),
        Tanh,  # for some reason the output gets stuck at a constant value when BatchNorm is added
        Dense(output_num))
Example No. 17
def logistic_regression(rng, dim):
    """Logistic regression."""
    init_params, forward = stax.serial(Dense(1), Sigmoid)
    temp, rng = random.split(rng)
    params = init_params(temp, (-1, dim))[1]
    return params, forward
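
A minimal usage sketch (x_batch is a placeholder input batch):

import jax.numpy as jnp

rng = random.PRNGKey(0)
params, forward = logistic_regression(rng, dim=3)
x_batch = jnp.ones((8, 3))
probs = forward(params, x_batch)  # shape (8, 1); the Sigmoid keeps values in (0, 1)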
Example No. 18
def value_decoder(output_num):
    return serial(
        Dense(128),
        Tanh,  # for some reason the output gets stuck at a constant value when BatchNorm is added
        Dense(output_num),
    )
Example No. 19
    set_step_size(lr_i)
    return opt_update(i, g, opt_state)


key = random.PRNGKey(3)
num_epochs = 60000
num_instances, num_vars = 200, 2
batch_size = num_instances
minim, maxim = -5, 5

x, y = generate_data(num_instances, 1, key)
X = np.c_[np.ones_like(x), x]
batches = data_stream(num_instances, batch_size)

init_random_params, predict = stax.serial(Dense(5), Softplus, Dense(5),
                                          Softplus, Dense(5), Softplus,
                                          Dense(5), Softplus, Dense(1))

lambd, step_size = 0.6, 1e-4
opt_init, opt_update, get_params, soft_thresholding, set_step_size = pgd(
    step_size, lambd)
_, init_params = init_random_params(key, (-1, num_vars))
opt_state = opt_init(init_params)
itercount = itertools.count()

for epoch in range(num_epochs):
    opt_state = update(next(itercount), opt_state, next(batches))

labels = {"training": "Data", "test": "Deep Neural Net"}
x_test = np.arange(minim, maxim, 1e-5)
x_test = np.c_[np.ones((x_test.shape[0], 1)), x_test]
Example No. 20
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(np.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)


init_random_params, predict = stax.serial(
    Dense(300), Relu,
    Dense(10), LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)
    
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9
    
    # training/test split
    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    
    # converting to batches
Example No. 21
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))


def accuracy(params, batch):
    inputs, targets = batch
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    return jnp.mean(predicted_class == target_class)


init_random_params, predict = stax.serial(Dense(config.num_hidden), Relu,
                                          Dense(config.num_hidden), Relu,
                                          Dense(10), LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)

    num_epochs = 10

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, config.batch_size)
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        rng = npr.RandomState(0)
        while True:
Example No. 22
def make_main(input_shape):
    # the number of output channels depends on the number of input channels
    return stax.serial(Conv(filters1, (1, 1)), BatchNorm(), Relu,
                       Conv(filters2, (ks, ks), padding='SAME'),
                       BatchNorm(), Relu, Conv(input_shape[3], (1, 1)),
                       BatchNorm())
Example No. 23
  plt.figure()
  plt.scatter(population_samples[:, 0], population_samples[:, 1])
  plt.title("Population")

  plt.figure()
  plt.scatter(biased_samples[:, 0], biased_samples[:, 1])
  plt.title("Biased sample")
  plt.show()

encoder_init, encoder = stax.serial(
    Dense(32),
    Relu,
    FanOut(2),
    stax.parallel(
        Dense(latent_dim),
        stax.serial(
            Dense(latent_dim),
            Softplus,
            Dampen(0.1, 1e-6),
        ),
    ),
    DistributionLayer(DiagMVN),
)

decoder_init, decoder = stax.serial(
    Dense(128),
    Relu,
    Dense(128),
    Relu,
    Dense(128),
    Relu,
    Dense(128),
Example No. 24
def wide_resnet_group(n, num_channels, strides=(1, 1)):
    blocks = [wide_resnet_block(num_channels, strides, channel_mismatch=True)]
    for _ in range(1, n):
        blocks += [wide_resnet_block(num_channels, strides=(1, 1))]
    return stax.serial(*blocks)
Example No. 25
def GCNPredicator(hidden_feats,
                  activation=None,
                  batchnorm=None,
                  dropout=None,
                  pooling_method='mean',
                  predicator_hidden_feats=64,
                  predicator_dropout=None,
                  n_out=1,
                  bias=True,
                  normalize=True):
    r"""GCN Predicator is a wrapper function using GCN and MLP.

    Parameters
    ----------
    hidden_feats : list[int]
        List of output node features.
    activation : list[Function] or None
        ``activation[i]`` is the activation function of the i-th GCN layer.
    batchnorm : list[bool] or None
        ``batchnorm[i]`` decides if batch normalization is to be applied on the output of
        the i-th GCN layer. ``len(batchnorm)`` equals the number of GCN layers. By default,
        batch normalization is applied for all GCN layers.
    dropout : list[float] or None
        ``dropout[i]`` decides the dropout probability on the output of the i-th GCN layer.
        ``len(dropout)`` equals the number of GCN layers. By default, no dropout is
        performed for all layers.
    pooling_method : str ('max', 'min', 'mean', 'sum')
        pooling method name
    predicator_hidden_feats : int
        Size of hidden graph representations in the predicator, defaults to 64.
    predicator_dropout : float or None
        The probability for dropout in the predicator, defaults to None.
    n_out : int
        Size of the output, defaults to 1.
    bias : bool
        Whether to add bias after the affine transformation, defaults to True.
    normalize : bool
        Whether to normalize the adjacency matrix, defaults to True.

    Returns
    -------
    init_fun : Function
        Initializes the parameters of the layer.
    apply_fun : Function
        Defines the forward computation function.
    """
    gcn_init, gcn_fun = GCN(hidden_feats, activation, batchnorm, dropout, bias,
                            normalize)
    pooling_fun = pooling(method=pooling_method)
    predicator_dropout = 0.0 if predicator_dropout is None else predicator_dropout
    _, drop_fun = Dropout(predicator_dropout)
    dnn_layers = [Dense(predicator_hidden_feats), Relu, Dense(n_out)]
    dnn_init, dnn_fun = serial(*dnn_layers)

    def init_fun(rng, input_shape):
        """Initialize parameters.

        Parameters
        ----------
        rng : PRNGKey
            rng is a value for generating random values.
        input_shape :  (batch_size, N, M1)
            The shape of input (input node features).
            N is the total number of nodes in the batch of graphs.
            M1 is the input node feature size.

        Returns
        -------
        output_shape : (batch_size, M4)
            The shape of output.
            M4 is the output size of GCNPredicator and equal to n_out.
        params: Tuple (gcn_param, dnn_param)
            gcn_param is all parameters of GCN.
            dnn_param is all parameters of full connected layer.
        """
        output_shape = input_shape
        rng, gcn_rng, dnn_rng = random.split(rng, 3)
        output_shape, gcn_param = gcn_init(gcn_rng, output_shape)
        # convert out_shape by pooling
        output_shape = (output_shape[0], output_shape[-1])
        output_shape, dnn_param = dnn_init(dnn_rng, output_shape)
        return output_shape, (gcn_param, dnn_param)

    def apply_fun(params, node_feats, adj, rng, is_train):
        """Define forward computation function.

        Parameters
        ----------
        node_feats : ndarray of shape (batch_size, N, M1)
            Batched input node features.
            N is the total number of nodes in the batch of graphs.
            M1 is the input node feature size.
        adj : ndarray of shape (batch_size, N, N)
            Batched adjacency matrix.
        rng : PRNGKey
            rng is a value for generating random values
        is_train : bool
            Whether the model is training or not.

        Returns
        -------
        out : ndarray of shape (batch_size, M4)
            The shape of output.
            M4 is the output size of GCNPredicator and equal to n_out.
        """
        gcn_param, dnn_param = params
        rng, gcn_rng, dropout_rng = random.split(rng, 3)
        node_feats = gcn_fun(gcn_param, node_feats, adj, gcn_rng, is_train)
        # pooling
        graph_feat = pooling_fun(node_feats)
        if predicator_dropout != 0.0:
            graph_feat = drop_fun(None, graph_feat, is_train, rng=dropout_rng)
        out = dnn_fun(dnn_param, graph_feat)
        return out

    return init_fun, apply_fun
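
A minimal usage sketch following the shapes documented above. GCN, pooling, Dropout and the stax layers come from the surrounding module; node_feats and adj are placeholder arrays of shape (batch_size, N, M1) and (batch_size, N, N):

rng = random.PRNGKey(0)
init_fun, apply_fun = GCNPredicator(hidden_feats=[64, 64], n_out=1)
out_shape, params = init_fun(rng, (32, 30, 16))                 # input shape (batch_size, N, M1)
out = apply_fun(params, node_feats, adj, rng, is_train=True)    # output shape (batch_size, n_out)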
Example No. 26
        f = partial(
            apply_fun_scan, params
        )  # We use partial to “clone” all the params to use at all timesteps.
        _, out = lax.scan(f, h, inputs)
        return out

    return init_fun, apply_fun


num_dims = 10  #0              # Number of OU timesteps
batch_size = 64  # Batchsize
num_hidden_units = 12  # GRU cells in the RNN layer

# Initialize the network and perform a forward pass
init_fun, gru_rnn = stax.serial(
    Dense(num_hidden_units), Relu, GRU(num_hidden_units), Dense(1)
)  # Dense(1) is applied to every per-timestep output from the GRU's lax.scan, hence the shape of the predictions.
_, params = init_fun(key, (batch_size, num_dims, 1))


def mse_loss(params, inputs, targets):
    """ Calculate the Mean Squared Error Prediction Loss. """
    preds = gru_rnn(params, inputs)
    return np.mean((preds - targets)**2)


@jit
def update(params, x, y, opt_state):
    """ Perform a forward pass, calculate the MSE & perform a SGD step. """
    loss, grads = value_and_grad(mse_loss)(params, x, y)
    opt_state = opt_update(0, grads, opt_state)
Example No. 27
def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))

def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)

init_random_params, predict = stax.serial(
    Dense(1024),
    Tanh,
    Dense(1024),
    Tanh,
    # Dense(1024), Tanh,
    # Dense(1024), Tanh,
    # Dense(1024), Tanh,
    # Dense(1024), Tanh,
    # Dense(1024), Tanh,
    # Dense(1024), Tanh,
    Dense(10),
    LogSoftmax)

if __name__ == "__main__":
  wandb.init(project="playing-the-lottery", entity="skainswo")

  rp = RngPooper(random.PRNGKey(0))

  config = wandb.config
  config.learning_rate = 0.001
  config.num_epochs = 100
Example No. 28
    return image_grid(nrow, ncol, sampled_images, (28, 28))


def image_grid(nrow, ncol, imagevecs, imshape):
    """Reshape a stack of image vectors into an image grid for plotting."""
    images = iter(imagevecs.reshape((-1, ) + imshape))
    return jnp.vstack([
        jnp.hstack([next(images).T for _ in range(ncol)][::-1])
        for _ in range(nrow)
    ]).T


encoder_init, encode = stax.serial(
    Dense(512),
    Relu,
    Dense(512),
    Relu,
    FanOut(2),
    stax.parallel(Dense(10), stax.serial(Dense(10), Softplus)),
)

decoder_init, decode = stax.serial(
    Dense(512),
    Relu,
    Dense(512),
    Relu,
    Dense(28 * 28),
)

if __name__ == "__main__":
    step_size = 0.001
    num_epochs = 100
Example No. 29
                   'Ratio of the standard deviation to the clipping norm')
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
flags.DEFINE_integer('batch_size', 256, 'Batch size')
flags.DEFINE_integer('epochs', 60, 'Number of epochs')
flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
flags.DEFINE_integer(
    'microbatches', None, 'Number of microbatches '
    '(must evenly divide batch_size)')
flags.DEFINE_string('model_dir', None, 'Model directory')

init_random_params, predict = stax.serial(
    stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,
    stax.Dense(32),
    stax.Relu,
    stax.Dense(10),
)


def loss(params, batch):
    inputs, targets = batch
    logits = predict(params, inputs)
    logits = stax.logsoftmax(logits)  # log normalize
    return -jnp.mean(jnp.sum(logits * targets, axis=1))  # cross entropy loss


def accuracy(params, batch):
Example No. 30
flags.DEFINE_string('clustering', 'KMeans',
                    'clustering method for projected embeddings')
flags.DEFINE_integer('ppc', 10, 'number of examples picked per cluster')
flags.DEFINE_integer('n_extra', 3000, 'number of extra points')
flags.DEFINE_integer('uncertain_extra', 1000, 'n_uncertain - n_extra')
flags.DEFINE_bool(
    'show_label', True,
    'visualize predicted label at top/left, true at bottom/right')

# BEGIN: define the classifier model
init_fn_0, apply_fn_0 = stax.serial(
    stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
    stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
    stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,  # (-1, 800)
    stax.Dense(32),
    stax.Tanh,  # embeddings
)

init_fn_1, apply_fn_1 = stax.serial(
    stax.Dense(10),  # logits
)


def predict(params, inputs):
    params_0 = params[:-1]
    params_1 = params[-1:]
    embeddings = apply_fn_0(params_0, inputs)